| @@ -102,6 +102,10 @@ int Reshape::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tenso | |||
| auto data = reinterpret_cast<int32_t *>(shape_tensor->Data()); | |||
| CalShape<int32_t>(data, inputs_, &out_shape, shape_size); | |||
| } break; | |||
| case kNumberTypeInt64: { | |||
| auto data = reinterpret_cast<int64_t *>(shape_tensor->Data()); | |||
| CalShape<int64_t>(data, inputs_, &out_shape, shape_size); | |||
| } break; | |||
| case kNumberTypeFloat: { | |||
| auto data = reinterpret_cast<float *>(shape_tensor->Data()); | |||
| CalShape<float>(data, inputs_, &out_shape, shape_size); | |||
| @@ -223,7 +223,6 @@ if(BUILD_CONVERTER) | |||
| ${LITE_DIR}/tools/converter/graphdef_transform.cc | |||
| ${LITE_DIR}/tools/converter/converter_flags.cc | |||
| ${LITE_DIR}/tools/converter/converter.cc | |||
| ${LITE_DIR}/tools/converter/parser/onnx/onnx.pb.cc | |||
| ${LITE_DIR}/test/st/converter_test.cc | |||
| ${LITE_DIR}/test/ut/tools/optimizer/fusion/conv_activation_fusion_test.cc | |||
| ${LITE_DIR}/test/ut/tools/optimizer/fusion/conv_biasadd_fusion_test.cc | |||
| @@ -351,6 +350,7 @@ if (BUILD_CONVERTER) | |||
| anf_importer_mid | |||
| tflite_parser_mid | |||
| caffe_parser_mid | |||
| onnx_parser_mid | |||
| node_mid | |||
| graph_pass_mid | |||
| fusion_mid | |||
| @@ -71,7 +71,6 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/../common/flag_parser.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/../common/storage.cc | |||
| # ${CMAKE_CURRENT_SOURCE_DIR}/../../src/common/anf_exporter/anf_exporter.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/parser/onnx/onnx.pb.cc | |||
| ../optimizer/common/node_pass_extends.cc | |||
| ../optimizer/common/pass_manager_extends.cc | |||
| @@ -86,6 +85,7 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} | |||
| add_subdirectory(parser/caffe) | |||
| add_subdirectory(parser/tflite) | |||
| add_subdirectory(parser/onnx) | |||
| add_subdirectory(legacy_optimizer) | |||
| add_subdirectory(quantizer) | |||
| @@ -98,6 +98,7 @@ add_executable(converter_lite | |||
| target_link_libraries(converter_lite PRIVATE | |||
| tflite_parser_mid | |||
| caffe_parser_mid | |||
| onnx_parser_mid | |||
| anf_importer_mid | |||
| node_mid | |||
| graph_pass_mid | |||
| @@ -27,6 +27,7 @@ | |||
| #include "tools/common/storage.h" | |||
| #include "parser/caffe/caffe_converter.h" | |||
| #include "parser/tflite/tflite_converter.h" | |||
| #include "parser/onnx/onnx_converter.h" | |||
| #include "src/common/anf_exporter/anf_exporter.h" | |||
| #include "src/common/anf_importer/import_from_protobuf.h" | |||
| #include "tools/converter/parser/onnx/onnx.pb.h" | |||
| @@ -185,6 +186,10 @@ int RunConverter(int argc, const char **argv) { | |||
| TfliteConverter tfLiteConverter; | |||
| fb_graph = tfLiteConverter.Convert(flags); | |||
| } break; | |||
| case FmkType::FmkType_ONNX: { | |||
| OnnxConverter onnxConverter; | |||
| fb_graph = onnxConverter.Convert(flags); | |||
| } break; | |||
| default: { | |||
| MS_LOG(ERROR) << "Unsupported fmkType: " << flags->fmk; | |||
| return 1; | |||
| @@ -14,13 +14,11 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #include "tools/converter/converter_flags.h" | |||
| #include <regex> | |||
| #include <string> | |||
| #include "ir/dtype/type_id.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| namespace converter { | |||
| @@ -89,8 +87,10 @@ int Flags::Init(int argc, const char **argv) { | |||
| this->fmk = FmkType_MS; | |||
| } else if (this->fmkIn == "TFLITE") { | |||
| this->fmk = FmkType_TFLITE; | |||
| } else if (this->fmkIn == "ONNX") { | |||
| this->fmk = FmkType_ONNX; | |||
| } else { | |||
| std::cerr << "INPUT ILLEGAL: fmk must be TFLITE|CAFFE|MS"; | |||
| std::cerr << "INPUT ILLEGAL: fmk must be TFLITE|CAFFE|MS|ONNX"; | |||
| return 1; | |||
| } | |||
| @@ -138,6 +138,12 @@ STATUS FormatTransPass::DoNodeInoutFormatTrans(schema::MetaGraphT *graph) { | |||
| } | |||
| beforeNodeType = kNCHW2NHWC; | |||
| afterNodeType = kNHWC2NCHW; | |||
| } else if (fmkType == converter::FmkType_ONNX) { | |||
| if (!IsContain(GetNhwcOpList(), GetCNodeTType(**iter))) { | |||
| continue; | |||
| } | |||
| beforeNodeType = kNCHW2NHWC; | |||
| afterNodeType = kNHWC2NCHW; | |||
| } else { | |||
| MS_LOG(ERROR) << "Unsupported fmk: " << fmkType; | |||
| return RET_ERROR; | |||
| @@ -197,4 +203,3 @@ void FormatTransPass::SetFmk(converter::FmkType fmkType) { this->fmkType = fmkTy | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -189,7 +189,7 @@ int WeightFormatPass::ShapeFormatTrans(GraphNode *graphNode) { | |||
| if (opType == schema::PrimitiveType_Conv2D) { | |||
| weightTensor->format = schema::Format_KCHW; | |||
| } else if (opType == schema::PrimitiveType_DepthwiseConv2D) { | |||
| weightTensor->format = schema::Format_CKHW; | |||
| weightTensor->format = schema::Format_KCHW; | |||
| } else if (opType == schema::PrimitiveType_DeConv2D) { | |||
| weightTensor->format = schema::Format_CKHW; | |||
| } else { | |||
| @@ -15,14 +15,15 @@ | |||
| */ | |||
| #include <memory> | |||
| #include "mindspore/lite/tools/converter/parser/onnx/onnx_argmax_parser.h" | |||
| #include "tools/converter/parser/onnx/onnx_argmax_parser.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS OnnxArgMaxParser::Parse(const onnx::GraphProto &onnx_graph, | |||
| const onnx::NodeProto &onnx_node, | |||
| schema::CNodeT *op) { | |||
| unique_ptr<schema::ArgMaxT> attr(new schema::ArgMaxT()); | |||
| MS_LOG(DEBUG) << "onnx ArgMaxParser"; | |||
| std::unique_ptr<schema::ArgMaxT> attr(new schema::ArgMaxT()); | |||
| for (const auto &onnx_node_attr : onnx_node.attribute()) { | |||
| const auto &attribute_name = onnx_node_attr.name(); | |||
| if (attribute_name == "axis") { | |||
| @@ -17,8 +17,8 @@ | |||
| #ifndef MS_ONNX_ARGMAX_PARSER_H | |||
| #define MS_ONNX_ARGMAX_PARSER_H | |||
| #include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h" | |||
| #include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h" | |||
| #include "tools/converter/parser/onnx/onnx_node_parser.h" | |||
| #include "tools/converter/parser/onnx/onnx_node_parser_registry.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| @@ -15,111 +15,118 @@ | |||
| */ | |||
| #include <memory> | |||
| #include "mindspore/lite/tools/converter/parser/onnx/onnx_arithmetic_operation_parser.h" | |||
| #include "tools/converter/parser/onnx/onnx_arithmetic_operation_parser.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS OnnxAddParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { | |||
| MS_LOG(DEBUG) << "onnx AddParser"; | |||
| if (op != nullptr) { | |||
| std::unique_ptr<schema::AddT> attr(new schema::AddT()); | |||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| op->primitive->value.type = schema::PrimitiveType_Add; | |||
| op->primitive->value.value = nullptr; | |||
| op->primitive->value.value = attr.release(); | |||
| } | |||
| return RET_OK; | |||
| } | |||
| STATUS OnnxSubParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { | |||
| MS_LOG(DEBUG) << "onnx SubParser"; | |||
| if (op != nullptr) { | |||
| std::unique_ptr<schema::SubT> attr(new schema::SubT()); | |||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| op->primitive->value.type = schema::PrimitiveType_Sub; | |||
| op->primitive->value.value = nullptr; | |||
| op->primitive->value.value = attr.release(); | |||
| } | |||
| return RET_OK; | |||
| } | |||
| STATUS OnnxMulParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { | |||
| MS_LOG(DEBUG) << "onnx MulParser"; | |||
| if (op != nullptr) { | |||
| std::unique_ptr<schema::MulT> attr(new schema::MulT()); | |||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| op->primitive->value.type = schema::PrimitiveType_Mul; | |||
| op->primitive->value.value = nullptr; | |||
| op->primitive->value.value = attr.release(); | |||
| } | |||
| return RET_OK; | |||
| } | |||
| STATUS OnnxDivParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { | |||
| MS_LOG(DEBUG) << "onnx DivParser"; | |||
| if (op != nullptr) { | |||
| std::unique_ptr<schema::DivT> attr(new schema::DivT()); | |||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| op->primitive->value.type = schema::PrimitiveType_RealDiv; | |||
| op->primitive->value.value = nullptr; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| STATUS OnnxMeanParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { | |||
| if (op != nullptr) { | |||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| op->primitive->value.type = schema::PrimitiveType_Mean; | |||
| op->primitive->value.value = nullptr; | |||
| op->primitive->value.type = schema::PrimitiveType_Div; | |||
| op->primitive->value.value = attr.release(); | |||
| } | |||
| return RET_OK; | |||
| } | |||
| STATUS OnnxPowParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { | |||
| MS_LOG(DEBUG) << "onnx PowParser"; | |||
| if (op != nullptr) { | |||
| // TODO(wangzhe) attr power need populate | |||
| std::unique_ptr<schema::PowerT> attr(new schema::PowerT()); | |||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| op->primitive->value.type = schema::PrimitiveType_Power; | |||
| op->primitive->value.value = nullptr; | |||
| op->primitive->value.value = attr.release(); | |||
| } | |||
| return RET_OK; | |||
| } | |||
| STATUS OnnxEqualParser::Parse(const onnx::GraphProto &onnx_graph, | |||
| const onnx::NodeProto &onnx_node, | |||
| STATUS OnnxEqualParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, | |||
| schema::CNodeT *op) { | |||
| MS_LOG(DEBUG) << "onnx EqualParser"; | |||
| if (op != nullptr) { | |||
| std::unique_ptr<schema::EqualT> attr(new schema::EqualT()); | |||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| op->primitive->value.type = schema::PrimitiveType_Equal; | |||
| op->primitive->value.value = nullptr; | |||
| op->primitive->value.value = attr.release(); | |||
| } | |||
| return RET_OK; | |||
| } | |||
| STATUS OnnxLessParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { | |||
| MS_LOG(DEBUG) << "onnx LessParser"; | |||
| if (op != nullptr) { | |||
| std::unique_ptr<schema::LessT> attr(new schema::LessT()); | |||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| op->primitive->value.type = schema::PrimitiveType_Less; | |||
| op->primitive->value.value = nullptr; | |||
| op->primitive->value.value = attr.release(); | |||
| } | |||
| return RET_OK; | |||
| } | |||
| STATUS OnnxGreaterParser::Parse(const onnx::GraphProto &onnx_graph, | |||
| const onnx::NodeProto &onnx_node, | |||
| STATUS OnnxGreaterParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, | |||
| schema::CNodeT *op) { | |||
| MS_LOG(DEBUG) << "onnx GreaterParser"; | |||
| if (op != nullptr) { | |||
| std::unique_ptr<schema::GreaterT> attr(new schema::GreaterT()); | |||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| op->primitive->value.type = schema::PrimitiveType_Greater; | |||
| op->primitive->value.value = nullptr; | |||
| op->primitive->value.value = attr.release(); | |||
| } | |||
| return RET_OK; | |||
| } | |||
| STATUS OnnxMinParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { | |||
| MS_LOG(DEBUG) << "onnx MinParser"; | |||
| if (op != nullptr) { | |||
| std::unique_ptr<schema::MinT> attr(new schema::MinT()); | |||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| op->primitive->value.type = schema::PrimitiveType_Min; | |||
| op->primitive->value.value = nullptr; | |||
| op->primitive->value.value = attr.release(); | |||
| } | |||
| return RET_OK; | |||
| } | |||
| STATUS OnnxEltwiseParser::Parse(const onnx::GraphProto &onnx_graph, | |||
| const onnx::NodeProto &onnx_node, | |||
| STATUS OnnxEltwiseParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, | |||
| schema::CNodeT *op) { | |||
| MS_LOG(DEBUG) << "onnx EltwiseParser"; | |||
| std::unique_ptr<schema::EltwiseT> attr(new schema::EltwiseT()); | |||
| if (onnx_node.op_type() == "Prod") { | |||
| attr->mode = schema::EltwiseMode_PROD; | |||
| } else if (onnx_node.op_type() == "Sum") { | |||
| // there is no Prod in onnx | |||
| if (onnx_node.op_type() == "Sum") { | |||
| attr->mode = schema::EltwiseMode_SUM; | |||
| } else if (onnx_node.op_type() == "Maximum") { | |||
| } else if (onnx_node.op_type() == "Max") { | |||
| attr->mode = schema::EltwiseMode_MAXIMUM; | |||
| } | |||
| @@ -131,109 +138,133 @@ STATUS OnnxEltwiseParser::Parse(const onnx::GraphProto &onnx_graph, | |||
| return RET_OK; | |||
| } | |||
| STATUS OnnxFloorParser::Parse(const onnx::GraphProto &onnx_graph, | |||
| const onnx::NodeProto &onnx_node, | |||
| STATUS OnnxFloorParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, | |||
| schema::CNodeT *op) { | |||
| MS_LOG(DEBUG) << "onnx FloorParser"; | |||
| if (op != nullptr) { | |||
| std::unique_ptr<schema::FloorT> attr(new schema::FloorT()); | |||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| op->primitive->value.type = schema::PrimitiveType_Floor; | |||
| op->primitive->value.value = nullptr; | |||
| op->primitive->value.value = attr.release(); | |||
| } | |||
| return RET_OK; | |||
| } | |||
| STATUS OnnxAbsParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { | |||
| MS_LOG(DEBUG) << "onnx AbsParser"; | |||
| if (op != nullptr) { | |||
| std::unique_ptr<schema::AbsT> attr(new schema::AbsT()); | |||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| op->primitive->value.type = schema::PrimitiveType_Abs; | |||
| op->primitive->value.value = nullptr; | |||
| op->primitive->value.value = attr.release(); | |||
| } | |||
| return RET_OK; | |||
| } | |||
| STATUS OnnxNegParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { | |||
| MS_LOG(DEBUG) << "onnx NegParser"; | |||
| if (op != nullptr) { | |||
| std::unique_ptr<schema::NegT> attr(new schema::NegT()); | |||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| op->primitive->value.type = schema::PrimitiveType_Neg; | |||
| op->primitive->value.value = nullptr; | |||
| op->primitive->value.value = attr.release(); | |||
| } | |||
| return RET_OK; | |||
| } | |||
| STATUS OnnxExpParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { | |||
| MS_LOG(DEBUG) << "onnx ExpParser"; | |||
| if (op != nullptr) { | |||
| std::unique_ptr<schema::ExpT> attr(new schema::ExpT()); | |||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| op->primitive->value.type = schema::PrimitiveType_Exp; | |||
| op->primitive->value.value = nullptr; | |||
| op->primitive->value.value = attr.release(); | |||
| } | |||
| return RET_OK; | |||
| } | |||
| STATUS OnnxCosParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { | |||
| MS_LOG(DEBUG) << "onnx CosParser"; | |||
| if (op != nullptr) { | |||
| std::unique_ptr<schema::CosT> attr(new schema::CosT()); | |||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| op->primitive->value.type = schema::PrimitiveType_Cos; | |||
| op->primitive->value.value = nullptr; | |||
| op->primitive->value.value = attr.release(); | |||
| } | |||
| return RET_OK; | |||
| } | |||
| STATUS OnnxSinParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { | |||
| MS_LOG(DEBUG) << "onnx SinParser"; | |||
| if (op != nullptr) { | |||
| std::unique_ptr<schema::SinT> attr(new schema::SinT()); | |||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| op->primitive->value.type = schema::PrimitiveType_Sin; | |||
| op->primitive->value.value = nullptr; | |||
| op->primitive->value.value = attr.release(); | |||
| } | |||
| return RET_OK; | |||
| } | |||
| STATUS OnnxSqrtParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { | |||
| MS_LOG(DEBUG) << "onnx SqrtParser"; | |||
| if (op != nullptr) { | |||
| std::unique_ptr<schema::SqrtT> attr(new schema::SqrtT()); | |||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| op->primitive->value.type = schema::PrimitiveType_Sqrt; | |||
| op->primitive->value.value = nullptr; | |||
| op->primitive->value.value = attr.release(); | |||
| } | |||
| return RET_OK; | |||
| } | |||
| STATUS OnnxCeilParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { | |||
| MS_LOG(DEBUG) << "onnx CeilParser"; | |||
| if (op != nullptr) { | |||
| std::unique_ptr<schema::CeilT> attr(new schema::CeilT()); | |||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| op->primitive->value.type = schema::PrimitiveType_Ceil; | |||
| op->primitive->value.value = nullptr; | |||
| op->primitive->value.value = attr.release(); | |||
| } | |||
| return RET_OK; | |||
| } | |||
| STATUS OnnxLogParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { | |||
| MS_LOG(DEBUG) << "onnx LogParser"; | |||
| if (op != nullptr) { | |||
| std::unique_ptr<schema::LogT> attr(new schema::LogT()); | |||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| op->primitive->value.type = schema::PrimitiveType_Log; | |||
| op->primitive->value.value = nullptr; | |||
| op->primitive->value.value = attr.release(); | |||
| } | |||
| return RET_OK; | |||
| } | |||
| STATUS OnnxTanParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { | |||
| MS_LOG(DEBUG) << "onnx TanParser"; | |||
| if (op != nullptr) { | |||
| std::unique_ptr<schema::TanT> attr(new schema::TanT()); | |||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| op->primitive->value.type = schema::PrimitiveType_Tan; | |||
| op->primitive->value.value = nullptr; | |||
| op->primitive->value.value = attr.release(); | |||
| } | |||
| return RET_OK; | |||
| } | |||
| STATUS OnnxAtanParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { | |||
| MS_LOG(DEBUG) << "onnx AtanParser"; | |||
| if (op != nullptr) { | |||
| std::unique_ptr<schema::AtanT> attr(new schema::AtanT()); | |||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| op->primitive->value.type = schema::PrimitiveType_Atan; | |||
| op->primitive->value.value = nullptr; | |||
| op->primitive->value.value = attr.release(); | |||
| } | |||
| return RET_OK; | |||
| } | |||
| STATUS OnnxAsinParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { | |||
| MS_LOG(DEBUG) << "onnx AsinParser"; | |||
| if (op != nullptr) { | |||
| std::unique_ptr<schema::AsinT> attr(new schema::AsinT()); | |||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| op->primitive->value.type = schema::PrimitiveType_Asin; | |||
| op->primitive->value.value = nullptr; | |||
| op->primitive->value.value = attr.release(); | |||
| } | |||
| return RET_OK; | |||
| } | |||
| STATUS OnnxTanhParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { | |||
| MS_LOG(DEBUG) << "onnx TanhParser"; | |||
| if (op != nullptr) { | |||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| op->primitive->value.value = nullptr; | |||
| MS_LOG(ERROR) << "mslite don't support tanh now"; | |||
| return RET_ERROR; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| @@ -243,13 +274,12 @@ OnnxNodeRegistrar g_onnxInt8AddParser("Int8Add", new OnnxAddParser()); | |||
| OnnxNodeRegistrar g_onnxSubParser("Sub", new OnnxSubParser()); | |||
| OnnxNodeRegistrar g_onnxMulParser("Mul", new OnnxMulParser()); | |||
| OnnxNodeRegistrar g_onnxDivParser("Div", new OnnxDivParser()); | |||
| OnnxNodeRegistrar g_onnxMeanParser("Mean", new OnnxMeanParser()); | |||
| // OnnxNodeRegistrar g_onnxMeanParser("Mean", new OnnxMeanParser()); // onnx's Mean is different from mslite's | |||
| OnnxNodeRegistrar g_onnxPowParser("Power", new OnnxPowParser()); | |||
| OnnxNodeRegistrar g_onnxEqualParser("Equal", new OnnxEqualParser()); | |||
| OnnxNodeRegistrar g_onnxLessParser("Less", new OnnxLessParser()); | |||
| OnnxNodeRegistrar g_onnxGreaterParser("Greater", new OnnxGreaterParser()); | |||
| OnnxNodeRegistrar g_onnxMinParser("Min", new OnnxMinParser()); | |||
| OnnxNodeRegistrar g_onnxProdParser("Prod", new OnnxEltwiseParser()); | |||
| OnnxNodeRegistrar g_onnxSumParser("Sum", new OnnxEltwiseParser()); | |||
| OnnxNodeRegistrar g_onnxMaxParser("Max", new OnnxEltwiseParser()); | |||
| OnnxNodeRegistrar g_onnxFloorParser("Floor", new OnnxFloorParser()); | |||
| @@ -267,4 +297,3 @@ OnnxNodeRegistrar g_onnxAsinParser("Asin", new OnnxAsinParser()); | |||
| OnnxNodeRegistrar g_onnxTanhParser("Tanh", new OnnxTanhParser()); | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -17,8 +17,8 @@ | |||
| #ifndef MS_ONNX_ARITHMETIC_OPREATION_PARSER_H | |||
| #define MS_ONNX_ARITHMETIC_OPREATION_PARSER_H | |||
| #include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h" | |||
| #include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h" | |||
| #include "tools/converter/parser/onnx/onnx_node_parser.h" | |||
| #include "tools/converter/parser/onnx/onnx_node_parser_registry.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| @@ -14,14 +14,15 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #include "mindspore/lite/tools/converter/parser/onnx/onnx_batchnorm_parser.h" | |||
| #include "tools/converter/parser/onnx/onnx_batchnorm_parser.h" | |||
| #include <memory> | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS OnnxBatchNormParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, | |||
| schema::CNodeT *op) { | |||
| unique_ptr<schema::FusedBatchNormT> attr(new schema::FusedBatchNormT()); | |||
| MS_LOG(DEBUG) << "onnx BatchNormParser"; | |||
| std::unique_ptr<schema::FusedBatchNormT> attr(new schema::FusedBatchNormT()); | |||
| for (const auto &onnx_node_attr : onnx_node.attribute()) { | |||
| if (onnx_node_attr.name() == "epsilon") { | |||
| attr->epsilon = onnx_node_attr.f(); | |||
| @@ -17,8 +17,8 @@ | |||
| #ifndef MS_ONNX_ADD_PARSER_H | |||
| #define MS_ONNX_ADD_PARSER_H | |||
| #include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h" | |||
| #include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h" | |||
| #include "tools/converter/parser/onnx/onnx_node_parser.h" | |||
| #include "tools/converter/parser/onnx/onnx_node_parser_registry.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| @@ -15,7 +15,7 @@ | |||
| */ | |||
| #include <memory> | |||
| #include "mindspore/lite/tools/converter/parser/onnx/onnx_biasadd_parser.h" | |||
| #include "tools/converter/parser/onnx/onnx_biasadd_parser.h" | |||
| // using namespace mindspore::predict; | |||
| // using namespace onnx; | |||
| @@ -25,7 +25,8 @@ namespace lite { | |||
| STATUS OnnxBiasAddParser::Parse(const onnx::GraphProto &onnx_graph, | |||
| const onnx::NodeProto &onnx_node, | |||
| schema::CNodeT *op) { | |||
| unique_ptr<schema::BiasAddT> attr(new schema::BiasAddT()); | |||
| MS_LOG(DEBUG) << "onnx BiasAddParser"; | |||
| std::unique_ptr<schema::BiasAddT> attr(new schema::BiasAddT()); | |||
| // use channel dim as axis | |||
| attr->axis = {1}; | |||
| if (op != nullptr) { | |||
| @@ -17,8 +17,8 @@ | |||
| #ifndef MS_ONNX_BIASADD_PARSER_H | |||
| #define MS_ONNX_BIASADD_PARSER_H | |||
| #include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h" | |||
| #include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h" | |||
| #include "tools/converter/parser/onnx/onnx_node_parser.h" | |||
| #include "tools/converter/parser/onnx/onnx_node_parser_registry.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| @@ -15,12 +15,13 @@ | |||
| */ | |||
| #include <memory> | |||
| #include "mindspore/lite/tools/converter/parser/onnx/onnx_cast_parser.h" | |||
| #include "tools/converter/parser/onnx/onnx_cast_parser.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS OnnxCastParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { | |||
| unique_ptr<schema::CastT> attr(new schema::CastT()); | |||
| MS_LOG(DEBUG) << "onnx CastParser"; | |||
| std::unique_ptr<schema::CastT> attr(new schema::CastT()); | |||
| for (const auto &onnx_node_attr : onnx_node.attribute()) { | |||
| const auto &attribute_name = onnx_node_attr.name(); | |||
| if (attribute_name == "to") { | |||
| @@ -17,8 +17,8 @@ | |||
| #ifndef MS_ONNX_CAST_PARSER_H | |||
| #define MS_ONNX_CAST_PARSER_H | |||
| #include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h" | |||
| #include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h" | |||
| #include "tools/converter/parser/onnx/onnx_node_parser.h" | |||
| #include "tools/converter/parser/onnx/onnx_node_parser_registry.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| @@ -15,24 +15,32 @@ | |||
| */ | |||
| #include <memory> | |||
| #include "mindspore/lite/tools/converter/parser/onnx/onnx_clip_parser.h" | |||
| #include "tools/converter/parser/onnx/onnx_clip_parser.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS OnnxClipParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { | |||
| unique_ptr<schema::ClipT> attr(new schema::ClipT()); | |||
| MS_LOG(DEBUG) << "onnx ClipParser"; | |||
| float min = -1, max = -1; | |||
| for (const auto &onnx_node_attr : onnx_node.attribute()) { | |||
| const auto &attribute_name = onnx_node_attr.name(); | |||
| if (attribute_name == "max") { | |||
| attr->max = onnx_node_attr.f(); | |||
| max = onnx_node_attr.f(); | |||
| } else if (attribute_name == "min") { | |||
| attr->min = onnx_node_attr.f(); | |||
| min = onnx_node_attr.f(); | |||
| } | |||
| } | |||
| if (op != nullptr) { | |||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| op->primitive->value.type = schema::PrimitiveType_Clip; | |||
| op->primitive->value.value = attr.release(); | |||
| if (min == 0 && max == 6) { | |||
| std::unique_ptr<schema::ActivationT> attr(new schema::ActivationT()); | |||
| attr->type = schema::ActivationType_RELU6; | |||
| if (op != nullptr) { | |||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| op->primitive->value.type = schema::PrimitiveType_Activation; | |||
| op->primitive->value.value = attr.release(); | |||
| } | |||
| } else { | |||
| MS_LOG(ERROR) << "only support convert clip(0,6) to relu6, other value is not supported"; | |||
| return RET_PARAM_INVALID; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| @@ -40,4 +48,3 @@ STATUS OnnxClipParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod | |||
| OnnxNodeRegistrar g_onnxClipParser("Clip", new OnnxClipParser()); | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -17,8 +17,8 @@ | |||
| #ifndef MS_ONNX_CLIP_PARSER_H | |||
| #define MS_ONNX_CLIP_PARSER_H | |||
| #include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h" | |||
| #include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h" | |||
| #include "tools/converter/parser/onnx/onnx_node_parser.h" | |||
| #include "tools/converter/parser/onnx/onnx_node_parser_registry.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| @@ -15,14 +15,15 @@ | |||
| */ | |||
| #include <memory> | |||
| #include "mindspore/lite/tools/converter/parser/onnx/onnx_concat_parser.h" | |||
| #include "tools/converter/parser/onnx/onnx_concat_parser.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS OnnxConcatParser::Parse(const onnx::GraphProto &onnx_graph, | |||
| const onnx::NodeProto &onnx_node, | |||
| schema::CNodeT *op) { | |||
| unique_ptr<schema::ConcatT> attr(new schema::ConcatT()); | |||
| MS_LOG(DEBUG) << "onnx ConcatParser"; | |||
| std::unique_ptr<schema::ConcatT> attr(new schema::ConcatT()); | |||
| for (const auto &onnx_node_attr : onnx_node.attribute()) { | |||
| const auto &attribute_name = onnx_node_attr.name(); | |||
| if (attribute_name == "axis") { | |||
| @@ -17,8 +17,8 @@ | |||
| #ifndef MS_ONNX_CONCAT_PARSER_H | |||
| #define MS_ONNX_CONCAT_PARSER_H | |||
| #include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h" | |||
| #include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h" | |||
| #include "tools/converter/parser/onnx/onnx_node_parser.h" | |||
| #include "tools/converter/parser/onnx/onnx_node_parser_registry.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| @@ -15,17 +15,19 @@ | |||
| */ | |||
| #include <memory> | |||
| #include "mindspore/lite/tools/converter/parser/onnx/onnx_constant_parser.h" | |||
| #include "tools/converter/parser/onnx/onnx_constant_parser.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS OnnxConstantParser::Parse(const onnx::GraphProto &onnx_graph, | |||
| const onnx::NodeProto &onnx_node, | |||
| schema::CNodeT *op) { | |||
| MS_LOG(DEBUG) << "onnx ConstantParser"; | |||
| if (op != nullptr) { | |||
| std::unique_ptr<schema::ConstantT> attr(new schema::ConstantT()); | |||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| op->primitive->value.type = schema::PrimitiveType_Constant; | |||
| op->primitive->value.value = nullptr; | |||
| op->primitive->value.value = attr.release(); | |||
| } | |||
| return RET_OK; | |||
| } | |||
| @@ -17,8 +17,8 @@ | |||
| #ifndef MS_ONNX_CONSTANT_PARSER_H | |||
| #define MS_ONNX_CONSTANT_PARSER_H | |||
| #include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h" | |||
| #include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h" | |||
| #include "tools/converter/parser/onnx/onnx_node_parser.h" | |||
| #include "tools/converter/parser/onnx/onnx_node_parser_registry.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| @@ -17,17 +17,18 @@ | |||
| #include <vector> | |||
| #include <memory> | |||
| #include <algorithm> | |||
| #include "mindspore/lite/tools/converter/parser/onnx/onnx_conv_parser.h" | |||
| #include "tools/converter/parser/onnx/onnx_conv_parser.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| bool OnnxConvParser::ParseGroupConvolution(schema::CNodeT *op, schema::Conv2DT *attr) { | |||
| MS_LOG(DEBUG) << "onnx DepthwiseConvParser"; | |||
| if (attr == nullptr || attr->group != attr->channelIn) { | |||
| return false; | |||
| } | |||
| std::unique_ptr<schema::DepthwiseConv2DT> depthwiseConv2DParam(new (std::nothrow) schema::DepthwiseConv2DT()); | |||
| if (depthwiseConv2DParam == nullptr) { | |||
| // MS_LOGW("new DepthwiseConv2DT failed"); | |||
| MS_LOG(ERROR) << "new DepthwiseConv2DT failed"; | |||
| return false; | |||
| } | |||
| depthwiseConv2DParam->format = attr->format; | |||
| @@ -48,12 +49,12 @@ bool OnnxConvParser::ParseGroupConvolution(schema::CNodeT *op, schema::Conv2DT * | |||
| depthwiseConv2DParam->activationType = attr->activationType; | |||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| op->primitive->value.type = schema::PrimitiveType_DepthwiseConv2D; | |||
| delete (op->primitive->value.value); | |||
| op->primitive->value.value = depthwiseConv2DParam.release(); | |||
| return true; | |||
| } | |||
| STATUS OnnxConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { | |||
| MS_LOG(DEBUG) << "onnx ConvParser"; | |||
| auto attr = new schema::Conv2DT(); | |||
| // set opdef each attr params | |||
| for (const auto &onnx_node_attr : onnx_node.attribute()) { | |||
| @@ -61,30 +62,32 @@ STATUS OnnxConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod | |||
| attr->group = static_cast<int32_t>(onnx_node_attr.i()); | |||
| } else if (onnx_node_attr.name() == "dilations") { | |||
| if (onnx_node_attr.ints().size() != 2) { | |||
| // MS_LOGE("dilations size %d is not 2", onnx_node_attr.ints().size()); | |||
| MS_LOG(ERROR) << "dilations size " << onnx_node_attr.ints().size() << " is not 2"; | |||
| return RET_ERROR; | |||
| } | |||
| attr->dilateW = static_cast<int32_t>(onnx_node_attr.ints(0)); | |||
| attr->dilateH = static_cast<int32_t>(onnx_node_attr.ints(1)); | |||
| // TODO(wangzhe) verify the change | |||
| attr->dilateH = static_cast<int32_t>(onnx_node_attr.ints(0)); | |||
| attr->dilateW = static_cast<int32_t>(onnx_node_attr.ints(1)); | |||
| } else if (onnx_node_attr.name() == "kernels") { | |||
| if (onnx_node_attr.ints().size() != 2) { | |||
| // MS_LOGE("kernel_shape size %d is not 2", onnx_node_attr.ints().size()); | |||
| MS_LOG(ERROR) << "kernel_shape size " << onnx_node_attr.ints().size() << " is not 2"; | |||
| return RET_ERROR; | |||
| } | |||
| attr->kernelH = static_cast<int32_t>(onnx_node_attr.ints(0)); | |||
| attr->kernelW = static_cast<int32_t>(onnx_node_attr.ints(1)); | |||
| } else if (onnx_node_attr.name() == "kernel_shape") { | |||
| if (onnx_node_attr.ints().size() != 2) { | |||
| // MS_LOGE("kernel_shape size %d is not 2", onnx_node_attr.ints().size()); | |||
| MS_LOG(ERROR) << "kernel_shape size " << onnx_node_attr.ints().size() << " is not 2"; | |||
| return RET_ERROR; | |||
| } | |||
| attr->kernelW = static_cast<int32_t>(onnx_node_attr.ints(0)); | |||
| attr->kernelH = static_cast<int32_t>(onnx_node_attr.ints(1)); | |||
| // TODO(wangzhe) verify the change | |||
| attr->kernelH = static_cast<int32_t>(onnx_node_attr.ints(0)); | |||
| attr->kernelW = static_cast<int32_t>(onnx_node_attr.ints(1)); | |||
| } else if (onnx_node_attr.name() == "auto_pad") { | |||
| attr->padMode = GetOnnxPadMode(onnx_node_attr); | |||
| } else if (onnx_node_attr.name() == "pads") { | |||
| if (onnx_node_attr.ints().size() != 4) { | |||
| // MS_LOGE("pads size %d is not 4", onnx_node_attr.ints().size()); | |||
| MS_LOG(ERROR) << "pads size " << onnx_node_attr.ints().size() << " is not 4"; | |||
| return RET_ERROR; | |||
| } | |||
| attr->padUp = static_cast<int32_t>(onnx_node_attr.ints(0)); | |||
| @@ -93,16 +96,17 @@ STATUS OnnxConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod | |||
| attr->padRight = static_cast<int32_t>(onnx_node_attr.ints(3)); | |||
| } else if (onnx_node_attr.name() == "strides") { | |||
| if (onnx_node_attr.ints().size() != 2) { | |||
| // MS_LOGE("strides size %d is not 2", onnx_node_attr.ints().size()); | |||
| MS_LOG(ERROR) << "strides size " << onnx_node_attr.ints().size() << " is not 2"; | |||
| return RET_ERROR; | |||
| } | |||
| attr->strideW = static_cast<int32_t>(onnx_node_attr.ints(0)); | |||
| attr->strideH = static_cast<int32_t>(onnx_node_attr.ints(1)); | |||
| // TODO(wangzhe) verify the change | |||
| attr->strideH = static_cast<int32_t>(onnx_node_attr.ints(0)); | |||
| attr->strideW = static_cast<int32_t>(onnx_node_attr.ints(1)); | |||
| } else if (onnx_node_attr.name() == "order") { | |||
| if (onnx_node_attr.s() == "NHWC") { | |||
| attr->format = schema::Format_NHWC; | |||
| } else { | |||
| // MS_LOGE("Unsupported format: %s", onnx_node_attr.s().c_str()); | |||
| MS_LOG(ERROR) << "Unsupported format: " << onnx_node_attr.s(); | |||
| return RET_ERROR; | |||
| } | |||
| } | |||
| @@ -114,7 +118,7 @@ STATUS OnnxConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod | |||
| std::find_if(onnx_graph.initializer().begin(), onnx_graph.initializer().end(), | |||
| [onnx_conv_weight](const onnx::TensorProto &proto) { return proto.name() == onnx_conv_weight; }); | |||
| if (nodeIter == onnx_graph.initializer().end()) { | |||
| // MS_LOGE("not find node: %s", onnx_conv_weight.c_str()) | |||
| MS_LOG(ERROR) << "not find node: " << onnx_conv_weight; | |||
| return RET_ERROR; | |||
| } | |||
| std::vector<int> weight_shape; | |||
| @@ -129,7 +133,7 @@ STATUS OnnxConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod | |||
| std::find_if(onnx_graph.node().begin(), onnx_graph.node().end(), | |||
| [onnx_conv_weight](const onnx::NodeProto &proto) { return proto.output(0) == onnx_conv_weight; }); | |||
| if (nodeIter == onnx_graph.node().end()) { | |||
| // MS_LOGE("can not find node: %s", onnx_conv_weight.c_str()) | |||
| MS_LOG(ERROR) << "can not find node: " << onnx_conv_weight; | |||
| return RET_ERROR; | |||
| } | |||
| std::vector<int> dims; | |||
| @@ -139,6 +143,7 @@ STATUS OnnxConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod | |||
| dims.insert(dims.begin(), iter->ints().begin(), iter->ints().end()); | |||
| } | |||
| attr->channelOut = dims[0]; | |||
| // TODO(wangzhe) verify this code | |||
| attr->channelIn = dims[3] * attr->group; | |||
| } | |||
| attr->format = schema::Format_NCHW; | |||
| @@ -156,7 +161,7 @@ STATUS OnnxConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod | |||
| if (attr->group != 1) { | |||
| if (!ParseGroupConvolution(op, attr)) { | |||
| delete attr; | |||
| // MS_LOGE("Convert Convolution to Depthwise failed"); | |||
| MS_LOG(ERROR) << "Convert Convolution to Depthwise failed"; | |||
| return RET_ERROR; | |||
| } | |||
| } | |||
| @@ -169,4 +174,3 @@ OnnxNodeRegistrar g_onnxConvReluParser("ConvRelu", new OnnxConvParser()); | |||
| OnnxNodeRegistrar g_onnxInt8ConvReluParser("Int8ConvRelu", new OnnxConvParser()); | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -17,8 +17,8 @@ | |||
| #ifndef MS_ONNX_CONV_PARSER_H | |||
| #define MS_ONNX_CONV_PARSER_H | |||
| #include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h" | |||
| #include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h" | |||
| #include "tools/converter/parser/onnx/onnx_node_parser.h" | |||
| #include "tools/converter/parser/onnx/onnx_node_parser_registry.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| @@ -14,7 +14,8 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #include "mindspore/lite/tools/converter/parser/onnx/onnx_converter.h" | |||
| #include "tools/converter/parser/onnx/onnx_converter.h" | |||
| #include "tools/converter/parser/onnx/onnx_model_parser.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| @@ -18,9 +18,8 @@ | |||
| #define MS_ONNX_CONVERTER_H | |||
| #include <string> | |||
| #include <memory> | |||
| #include "mindspore/lite/tools/converter/converter.h" | |||
| #include "mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.h" | |||
| #include "mindspore/lite/tools/converter/graphdef_transform.h" | |||
| #include "tools/converter/converter.h" | |||
| #include "tools/converter/graphdef_transform.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| @@ -17,11 +17,12 @@ | |||
| #include <vector> | |||
| #include <memory> | |||
| #include <algorithm> | |||
| #include "mindspore/lite/tools/converter/parser/onnx/onnx_deconv_parser.h" | |||
| #include "tools/converter/parser/onnx/onnx_deconv_parser.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| bool OnnxDeConvParser::ParseGroupDeConvolution(schema::CNodeT *op, schema::DeConv2DT *attr) { | |||
| MS_LOG(DEBUG) << "onnx DeConvParser"; | |||
| if (attr == nullptr || attr->group != attr->channelOut) { | |||
| return false; | |||
| } | |||
| @@ -17,8 +17,8 @@ | |||
| #ifndef MS_ONNX_DECONV_PARSER_H | |||
| #define MS_ONNX_DECONV_PARSER_H | |||
| #include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h" | |||
| #include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h" | |||
| #include "tools/converter/parser/onnx/onnx_node_parser.h" | |||
| #include "tools/converter/parser/onnx/onnx_node_parser_registry.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| @@ -15,14 +15,15 @@ | |||
| */ | |||
| #include <memory> | |||
| #include "mindspore/lite/tools/converter/parser/onnx/onnx_depth_to_space_parser.h" | |||
| #include "tools/converter/parser/onnx/onnx_depth_to_space_parser.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS OnnxDepthToSpaceParser::Parse(const onnx::GraphProto &onnx_graph, | |||
| const onnx::NodeProto &onnx_node, | |||
| schema::CNodeT *op) { | |||
| unique_ptr<schema::DepthToSpaceT> attr(new schema::DepthToSpaceT()); | |||
| MS_LOG(DEBUG) << "onnx DepthToSpaceParser"; | |||
| std::unique_ptr<schema::DepthToSpaceT> attr(new schema::DepthToSpaceT()); | |||
| for (const auto &onnx_node_attr : onnx_node.attribute()) { | |||
| const auto& attribute_name = onnx_node_attr.name(); | |||
| if (attribute_name == "blocksize") { | |||
| @@ -17,8 +17,8 @@ | |||
| #ifndef MS_ONNX_DEPTH_TO_SPACE_PARSER_H | |||
| #define MS_ONNX_DEPTH_TO_SPACE_PARSER_H | |||
| #include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h" | |||
| #include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h" | |||
| #include "tools/converter/parser/onnx/onnx_node_parser.h" | |||
| #include "tools/converter/parser/onnx/onnx_node_parser_registry.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| @@ -15,14 +15,15 @@ | |||
| */ | |||
| #include <memory> | |||
| #include "mindspore/lite/tools/converter/parser/onnx/onnx_dropout_parser.h" | |||
| #include "tools/converter/parser/onnx/onnx_dropout_parser.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS OnnxDropoutParser::Parse(const onnx::GraphProto &onnx_graph, | |||
| const onnx::NodeProto &onnx_node, | |||
| schema::CNodeT *op) { | |||
| unique_ptr<schema::DropoutT> attr(new schema::DropoutT()); | |||
| MS_LOG(DEBUG) << "onnx DropoutParser"; | |||
| std::unique_ptr<schema::DropoutT> attr(new schema::DropoutT()); | |||
| for (const auto &onnx_node_attr : onnx_node.attribute()) { | |||
| const auto &attribute_name = onnx_node_attr.name(); | |||
| if (attribute_name == "ratio") { | |||
| @@ -17,8 +17,8 @@ | |||
| #ifndef MS_ONNX_ARGMAX_PARSER_H | |||
| #define MS_ONNX_ARGMAX_PARSER_H | |||
| #include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h" | |||
| #include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h" | |||
| #include "tools/converter/parser/onnx/onnx_node_parser.h" | |||
| #include "tools/converter/parser/onnx/onnx_node_parser_registry.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| @@ -15,12 +15,13 @@ | |||
| */ | |||
| #include <memory> | |||
| #include "mindspore/lite/tools/converter/parser/onnx/onnx_elu_parser.h" | |||
| #include "tools/converter/parser/onnx/onnx_elu_parser.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS OnnxEluParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { | |||
| unique_ptr<schema::EluT> attr(new schema::EluT()); | |||
| MS_LOG(DEBUG) << "onnx EluParser"; | |||
| std::unique_ptr<schema::EluT> attr(new schema::EluT()); | |||
| for (const auto &onnx_node_attr : onnx_node.attribute()) { | |||
| const auto& attribute_name = onnx_node_attr.name(); | |||
| if (attribute_name == "alpha") { | |||
| @@ -17,8 +17,8 @@ | |||
| #ifndef MS_ONNX_ELU_PARSER_H | |||
| #define MS_ONNX_ELU_PARSER_H | |||
| #include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h" | |||
| #include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h" | |||
| #include "tools/converter/parser/onnx/onnx_node_parser.h" | |||
| #include "tools/converter/parser/onnx/onnx_node_parser_registry.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| @@ -15,17 +15,18 @@ | |||
| */ | |||
| #include <memory> | |||
| #include "mindspore/lite/tools/converter/parser/onnx/onnx_expand_parser.h" | |||
| #include "tools/converter/parser/onnx/onnx_expand_parser.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS OnnxExpandParser::Parse(const onnx::GraphProto &onnx_graph, | |||
| const onnx::NodeProto &onnx_node, | |||
| STATUS OnnxExpandParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, | |||
| schema::CNodeT *op) { | |||
| MS_LOG(DEBUG) << "onnx ExpandParser"; | |||
| if (op != nullptr) { | |||
| std::unique_ptr<schema::BroadcastT> attr(new schema::BroadcastT()); | |||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| op->primitive->value.type = schema::PrimitiveType_Broadcast; | |||
| op->primitive->value.value = nullptr; | |||
| op->primitive->value.value = attr.release(); | |||
| } | |||
| return RET_OK; | |||
| } | |||
| @@ -33,4 +34,3 @@ STATUS OnnxExpandParser::Parse(const onnx::GraphProto &onnx_graph, | |||
| OnnxNodeRegistrar g_onnxExpandSpaceParser("Expand", new OnnxExpandParser()); | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -17,8 +17,8 @@ | |||
| #ifndef MS_ONNX_EXPAND_PARSER_H | |||
| #define MS_ONNX_EXPAND_PARSER_H | |||
| #include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h" | |||
| #include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h" | |||
| #include "tools/converter/parser/onnx/onnx_node_parser.h" | |||
| #include "tools/converter/parser/onnx/onnx_node_parser_registry.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| @@ -15,14 +15,15 @@ | |||
| */ | |||
| #include <memory> | |||
| #include "mindspore/lite/tools/converter/parser/onnx/onnx_flatten_parser.h" | |||
| #include "tools/converter/parser/onnx/onnx_flatten_parser.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS OnnxFlattenParser::Parse(const onnx::GraphProto &onnx_graph, | |||
| const onnx::NodeProto &onnx_node, | |||
| schema::CNodeT *op) { | |||
| unique_ptr<schema::ReshapeT> attr(new schema::ReshapeT()); | |||
| MS_LOG(DEBUG) << "onnx FlattenParser"; | |||
| std::unique_ptr<schema::ReshapeT> attr(new schema::ReshapeT()); | |||
| int axis = 1; | |||
| for (const auto &onnx_node_attr : onnx_node.attribute()) { | |||
| const auto &attribute_name = onnx_node_attr.name(); | |||
| @@ -17,8 +17,8 @@ | |||
| #ifndef MS_ONNX_FLATTEN_PARSER_H | |||
| #define MS_ONNX_FLATTEN_PARSER_H | |||
| #include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h" | |||
| #include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h" | |||
| #include "tools/converter/parser/onnx/onnx_node_parser.h" | |||
| #include "tools/converter/parser/onnx/onnx_node_parser_registry.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| @@ -15,14 +15,15 @@ | |||
| */ | |||
| #include <memory> | |||
| #include "mindspore/lite/tools/converter/parser/onnx/onnx_gather_parser.h" | |||
| #include "tools/converter/parser/onnx/onnx_gather_parser.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS OnnxGatherParser::Parse(const onnx::GraphProto &onnx_graph, | |||
| const onnx::NodeProto &onnx_node, | |||
| schema::CNodeT *op) { | |||
| unique_ptr<schema::GatherT> attr(new schema::GatherT()); | |||
| MS_LOG(DEBUG) << "onnx GatherParser"; | |||
| std::unique_ptr<schema::GatherT> attr(new schema::GatherT()); | |||
| for (const auto &onnx_node_attr : onnx_node.attribute()) { | |||
| const auto& attribute_name = onnx_node_attr.name(); | |||
| if (attribute_name == "axis") { | |||
| @@ -17,8 +17,8 @@ | |||
| #ifndef MS_ONNX_GATHER_PARSER_H | |||
| #define MS_ONNX_GATHER_PARSER_H | |||
| #include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h" | |||
| #include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h" | |||
| #include "tools/converter/parser/onnx/onnx_node_parser.h" | |||
| #include "tools/converter/parser/onnx/onnx_node_parser_registry.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| @@ -15,12 +15,13 @@ | |||
| */ | |||
| #include <memory> | |||
| #include "mindspore/lite/tools/converter/parser/onnx/onnx_lrn_parser.h" | |||
| #include "tools/converter/parser/onnx/onnx_lrn_parser.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS OnnxLrnParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { | |||
| unique_ptr<schema::LrnT> attr(new schema::LrnT()); | |||
| MS_LOG(DEBUG) << "onnx LrnParser"; | |||
| std::unique_ptr<schema::LrnT> attr(new schema::LrnT()); | |||
| for (const auto &onnx_node_attr : onnx_node.attribute()) { | |||
| const auto& attribute_name = onnx_node_attr.name(); | |||
| if (attribute_name == "size") { | |||
| @@ -17,8 +17,8 @@ | |||
| #ifndef MS_ONNX_LRN_PARSER_H | |||
| #define MS_ONNX_LRN_PARSER_H | |||
| #include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h" | |||
| #include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h" | |||
| #include "tools/converter/parser/onnx/onnx_node_parser.h" | |||
| #include "tools/converter/parser/onnx/onnx_node_parser_registry.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| @@ -15,14 +15,14 @@ | |||
| */ | |||
| #include <memory> | |||
| #include "mindspore/lite/tools/converter/parser/onnx/onnx_matmul_parser.h" | |||
| #include "tools/converter/parser/onnx/onnx_matmul_parser.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS OnnxMatmulParser::Parse(const onnx::GraphProto &onnx_graph, | |||
| const onnx::NodeProto &onnx_node, | |||
| STATUS OnnxMatmulParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, | |||
| schema::CNodeT *op) { | |||
| unique_ptr<schema::MatMulT> attr(new schema::MatMulT()); | |||
| MS_LOG(DEBUG) << "onnx MatMulParser"; | |||
| std::unique_ptr<schema::MatMulT> attr(new schema::MatMulT()); | |||
| float alpha = 1.0f; | |||
| float beta = 1.0f; | |||
| for (const auto &onnx_node_attr : onnx_node.attribute()) { | |||
| @@ -38,7 +38,7 @@ STATUS OnnxMatmulParser::Parse(const onnx::GraphProto &onnx_graph, | |||
| } | |||
| } | |||
| if (alpha != 1 || beta != 1) { | |||
| // MS_LOGE("not support alpha * A * B + beta * C"); | |||
| MS_LOG(ERROR) << "not support alpha * A * B + beta * C"; | |||
| return RET_PARAM_INVALID; | |||
| } | |||
| @@ -53,4 +53,3 @@ STATUS OnnxMatmulParser::Parse(const onnx::GraphProto &onnx_graph, | |||
| OnnxNodeRegistrar g_onnxMatmulParser("MatMul", new OnnxMatmulParser()); | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -17,8 +17,8 @@ | |||
| #ifndef MS_ONNX_MATMUL_PARSER_H | |||
| #define MS_ONNX_MATMUL_PARSER_H | |||
| #include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h" | |||
| #include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h" | |||
| #include "tools/converter/parser/onnx/onnx_node_parser.h" | |||
| #include "tools/converter/parser/onnx/onnx_node_parser_registry.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| @@ -18,7 +18,7 @@ | |||
| #include <unordered_map> | |||
| #include <algorithm> | |||
| #include <utility> | |||
| #include "mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.h" | |||
| #include "tools/converter/parser/onnx/onnx_model_parser.h" | |||
| #include "tools/common/graph_util.h" | |||
| #include "src/common/utils.h" | |||
| @@ -35,11 +35,12 @@ static const std::unordered_map<int, mindspore::TypeId> TYPE_MAP = { | |||
| {onnx::TensorProto_DataType_UINT32, mindspore::kNumberTypeUInt32}, | |||
| {onnx::TensorProto_DataType_INT64, mindspore::kNumberTypeInt64}, | |||
| {onnx::TensorProto_DataType_FLOAT16, mindspore::kNumberTypeFloat16}, | |||
| {onnx::TensorProto_DataType_FLOAT, mindspore::kNumberTypeFloat}}; | |||
| {onnx::TensorProto_DataType_FLOAT, mindspore::kNumberTypeFloat32}}; | |||
| TypeId OnnxModelParser::GetDateTypeFromOnnx(onnx::TensorProto_DataType onnx_type) { | |||
| TypeId OnnxModelParser::GetDataTypeFromOnnx(onnx::TensorProto_DataType onnx_type) { | |||
| auto iter = TYPE_MAP.find(onnx_type); | |||
| if (iter == TYPE_MAP.end()) { | |||
| MS_LOG(ERROR) << "unsupported onnx data type: " << onnx_type; | |||
| return kTypeUnknown; | |||
| } | |||
| return iter->second; | |||
| @@ -56,7 +57,7 @@ std::vector<int32_t> OnnxModelParser::GetDimsFromOnnxValue(const onnx::ValueInfo | |||
| STATUS OnnxModelParser::ReadOnnxModelFromBinary(const std::string &modelFile, google::protobuf::Message *onnx_model) { | |||
| std::unique_ptr<char> onnx_file(new (std::nothrow) char[PATH_MAX]{0}); | |||
| if (realpath(modelFile.c_str(), onnx_file.get()) == nullptr) { | |||
| // MS_LOGE("get realpath %s fail", modelFile.c_str()); | |||
| MS_LOG(ERROR) << "get realpath " << modelFile << " fail"; | |||
| return RET_ERROR; | |||
| } | |||
| int fd = open(onnx_file.get(), O_RDONLY); | |||
| @@ -65,7 +66,7 @@ STATUS OnnxModelParser::ReadOnnxModelFromBinary(const std::string &modelFile, go | |||
| code_input.SetTotalBytesLimit(INT_MAX, 536870912); | |||
| bool ret = onnx_model->ParseFromCodedStream(&code_input); | |||
| if (!ret) { | |||
| // MS_LOGE("load onnx file failed"); | |||
| MS_LOG(ERROR) << "load onnx file failed"; | |||
| return RET_ERROR; | |||
| } | |||
| (void)close(fd); | |||
| @@ -73,46 +74,47 @@ STATUS OnnxModelParser::ReadOnnxModelFromBinary(const std::string &modelFile, go | |||
| } | |||
| STATUS OnnxModelParser::SetGraphConstTensor(const onnx::GraphProto &onnx_graph, TensorCache *tensor_cache) { | |||
| // MS_LOGD("set onnx constant tensors"); | |||
| MS_LOG(DEBUG) << "set onnx constant tensors"; | |||
| for (const auto &onnx_const_value : onnx_graph.initializer()) { | |||
| std::vector<int32_t> dims; | |||
| std::copy(onnx_const_value.dims().begin(), onnx_const_value.dims().end(), std::back_inserter(dims)); | |||
| auto data_type = GetDateTypeFromOnnx(static_cast<onnx::TensorProto_DataType>(onnx_const_value.data_type())); | |||
| auto data_type = GetDataTypeFromOnnx(static_cast<onnx::TensorProto_DataType>(onnx_const_value.data_type())); | |||
| if (data_type == kTypeUnknown) { | |||
| // MS_LOGE("not support onnx type %d", static_cast<onnx::TensorProto_DataType>(onnx_const_value.data_type())); | |||
| MS_LOG(ERROR) << "not support onnx data type " | |||
| << static_cast<onnx::TensorProto_DataType>(onnx_const_value.data_type()); | |||
| return RET_ERROR; | |||
| } | |||
| std::unique_ptr<schema::TensorT> tensor(new (std::nothrow) schema::TensorT); | |||
| if (tensor == nullptr) { | |||
| // MS_LOGE("new tensor failed"); | |||
| MS_LOG(ERROR) << "new tensor failed"; | |||
| return RET_ERROR; | |||
| } | |||
| tensor->dataType = data_type; | |||
| tensor->format = schema::Format_NCHW; | |||
| for (const auto &it : dims) { | |||
| tensor->dims.emplace_back(it); | |||
| } | |||
| tensor->format = schema::Format_NCHW; // onnx use NCHW | |||
| std::copy(onnx_const_value.dims().begin(), onnx_const_value.dims().end(), std::back_inserter(tensor->dims)); | |||
| tensor->nodeType = schema::NodeType_ValueNode; | |||
| if (CopyOnnxTensorData(onnx_const_value, tensor.get())) { | |||
| MS_LOG(ERROR) << "copy onnx data failed"; | |||
| return RET_ERROR; | |||
| } | |||
| // const auto index = tensor_cache->AddTensor(onnx_const_value.name(), tensor.release(), GRAPH_INPUT); | |||
| // MS_LOGD("add const tensor: %s, index %d", onnx_const_value.name().c_str(), index) | |||
| // TODO(wangzhe) why use GRAPH_INPUT other than CONST(GRAPH_INPUT will add index to graphInputs) | |||
| const auto index = tensor_cache->AddTensor(onnx_const_value.name(), tensor.release(), GRAPH_INPUT); | |||
| MS_LOG(DEBUG) << "add const tensor: " << onnx_const_value.name() << ", index " << index; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| // TODO(wangzhe) seems AddTensorCache should be renamed to prepare tensor to add to tensor_cache | |||
| STATUS OnnxModelParser::AddTensorCache(const onnx::ValueInfoProto &proto, schema::TensorT *tensor) { | |||
| auto data_type = GetDateTypeFromOnnx(static_cast<onnx::TensorProto_DataType>(proto.type().tensor_type().elem_type())); | |||
| auto data_type = GetDataTypeFromOnnx(static_cast<onnx::TensorProto_DataType>(proto.type().tensor_type().elem_type())); | |||
| if (data_type == kTypeUnknown) { | |||
| // MS_LOGE("not support onnx type %d", | |||
| // static_cast<onnx::TensorProto_DataType>(proto.type().tensor_type().elem_type())); | |||
| MS_LOG(ERROR) << "not support onnx type " | |||
| << static_cast<onnx::TensorProto_DataType>(proto.type().tensor_type().elem_type()); | |||
| return RET_ERROR; | |||
| } | |||
| tensor->dataType = data_type; | |||
| tensor->dims = GetDimsFromOnnxValue(proto); | |||
| tensor->format = schema::Format_NCHW; | |||
| tensor->nodeType = schema::NodeType_ValueNode; | |||
| // TODO(wangzhe) tensor->data and quantParams not set, should we need tensor_cache->AddTensor? | |||
| return RET_OK; | |||
| } | |||
| @@ -122,12 +124,14 @@ STATUS OnnxModelParser::SetGraphInputTensor(const onnx::GraphProto &onnx_graph, | |||
| auto ret = tensor_cache->FindTensor(input_value.name()); | |||
| if (ret < 0) { | |||
| std::unique_ptr<schema::TensorT> tensor(new schema::TensorT); | |||
| // TODO(wangzhe) why there is an addtensorCache? | |||
| if (AddTensorCache(input_value, tensor.get())) { | |||
| return RET_ERROR; | |||
| } | |||
| // TODO(wangzhe) why inputTensor is value and should be added into tensor_cache? | |||
| auto tensor_index = tensor_cache->AddTensor(input_value.name(), tensor.release(), GRAPH_INPUT); | |||
| graph->inputIndex.emplace_back(static_cast<uint32_t>(tensor_index)); | |||
| // MS_LOGD("input_value name: %s, graph input index: %d", input_value.name().c_str(), tensor_index); | |||
| MS_LOG(DEBUG) << "input_value name: " << input_value.name() << ", graph input index: " << tensor_index; | |||
| } | |||
| } | |||
| return RET_OK; | |||
| @@ -140,9 +144,10 @@ STATUS OnnxModelParser::SetGraphOutputTensor(const onnx::GraphProto &onnx_graph, | |||
| if (AddTensorCache(output_value, tensor.get())) { | |||
| return RET_ERROR; | |||
| } | |||
| // TODO(wangzhe) why we need AddTensor at OutputTensor | |||
| auto tensor_index = tensor_cache->AddTensor(output_value.name(), tensor.release(), OP_OUTPUT); | |||
| graph->outputIndex.emplace_back(tensor_index); | |||
| // MS_LOGD("output_value name: %s, graph output index: %d", output_value.name().c_str(), tensor_index); | |||
| MS_LOG(DEBUG) << "output_value name: " << output_value.name() << ", graph output index: " << tensor_index; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| @@ -151,7 +156,6 @@ void OnnxModelParser::ParseOnnxGemmNode(const onnx::GraphProto &onnx_graph, cons | |||
| schema::MetaGraphT *graph, TensorCache *tensor_cache) { | |||
| std::unique_ptr<schema::CNodeT> dst_op_1(new schema::CNodeT); | |||
| dst_op_1->name = "Gemm_MatMul_" + onnx_node.output(0); | |||
| // dst_op_1->fmkType = FmkType_ONNX; | |||
| ParseOnnxNodeAttr(onnx_graph, onnx_node, "MatMul", dst_op_1.get()); | |||
| auto matmul_output_id = "Gemm_MatMul_" + onnx_node.output(0); | |||
| std::vector<string> matmul_inputs{onnx_node.input(0), onnx_node.input(1)}; | |||
| @@ -162,7 +166,6 @@ void OnnxModelParser::ParseOnnxGemmNode(const onnx::GraphProto &onnx_graph, cons | |||
| std::unique_ptr<schema::CNodeT> dst_op_2(new schema::CNodeT); | |||
| dst_op_2->name = "Gemm_BiasAdd_" + onnx_node.output(0); | |||
| // dst_op_2->fmkType = FmkType_ONNX; | |||
| ParseOnnxNodeAttr(onnx_graph, onnx_node, "BiasAdd", dst_op_2.get()); | |||
| std::vector<string> biasadd_inputs{matmul_output_id, onnx_node.input(2)}; | |||
| std::vector<string> biasadd_outputs{onnx_node.output(0)}; | |||
| @@ -181,7 +184,7 @@ STATUS OnnxModelParser::ParseOnnxGivenFillNode(const onnx::NodeProto &onnx_node, | |||
| [](const onnx::AttributeProto &attr) { return attr.name() == "shape"; }); | |||
| if (iter != onnx_node.attribute().end()) { | |||
| (void)shape.insert(shape.begin(), iter->ints().begin(), iter->ints().end()); | |||
| std::for_each(shape.begin(), shape.end(), [](int sh) { /*MS_LOGD("shape: %d", sh);*/ }); | |||
| std::for_each(shape.begin(), shape.end(), [](int sh) { MS_LOG(DEBUG) << "shape: " << sh; }); | |||
| } | |||
| tensor->dims = shape; | |||
| tensor->format = schema::Format_NUM_OF_FORMAT; | |||
| @@ -210,51 +213,50 @@ STATUS OnnxModelParser::ParseOnnxGivenFillNode(const onnx::NodeProto &onnx_node, | |||
| // todo: add * sizof(string) | |||
| data_size = data_count; | |||
| tensor->data.resize(data_size); | |||
| // MS_LOGD("tensor data size %lu, s: %lu", data_size, sizeof(iter->s().data())); | |||
| MS_LOG(DEBUG) << "tensor data size " << data_size << ", s: " << sizeof(iter->s().data()); | |||
| if (memcpy_s(tensor->data.data(), data_size, iter->s().data(), data_size) != 0) { | |||
| // MS_LOGE("memcpy_s failed") | |||
| MS_LOG(ERROR) << "memcpy_s failed"; | |||
| return RET_ERROR; | |||
| } | |||
| } else { | |||
| // MS_LOGE("unsupported data type %d", tensor->dataType); | |||
| MS_LOG(ERROR) << "unsupported data type " << tensor->dataType; | |||
| return RET_ERROR; | |||
| } | |||
| } | |||
| auto index = tensor_cache->AddTensor(onnx_node.output(0), tensor.release(), GRAPH_INPUT); | |||
| // MS_LOGD("add given tensor: %d", index); | |||
| MS_LOG(DEBUG) << "add given tensor: " << index; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| STATUS OnnxModelParser::ParseOnnxNodeToDstOp(const onnx::GraphProto &onnx_graph, | |||
| const onnx::NodeProto &onnx_node, | |||
| schema::CNodeT *dst_op, | |||
| schema::TensorT *dst_tensor, | |||
| STATUS OnnxModelParser::ParseOnnxNodeToDstOp(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, | |||
| schema::CNodeT *dst_op, schema::TensorT *dst_tensor, | |||
| TensorCache *tensor_cache) { | |||
| // change op_type() to name(), that is unique | |||
| dst_op->name = onnx_node.op_type() + "_" + onnx_node.output(0); | |||
| // dst_op->fmkType = FmkType_ONNX; | |||
| // MS_LOGD("onnx op name %s, dst op name: %s, input size %d", onnx_node.op_type().c_str(), dst_op->name.c_str(), | |||
| // onnx_node.input_size()); | |||
| MS_LOG(DEBUG) << "onnx op name " << onnx_node.op_type() << ", dst op name: " << dst_op->name << ", input size " | |||
| << onnx_node.input_size(); | |||
| // get the real op type | |||
| SetOpQuantParams(onnx_graph, onnx_node, dst_op, dst_tensor, tensor_cache); | |||
| auto status = ParseOnnxNodeAttr(onnx_graph, onnx_node, onnx_node.op_type(), dst_op); | |||
| if (status != RET_OK) { | |||
| // MS_LOGE("parser onnx node attr failed"); | |||
| MS_LOG(ERROR) << "parser onnx node attr failed"; | |||
| return status; | |||
| } | |||
| // set op input index | |||
| std::vector<string> node_inputs; | |||
| (void)node_inputs.insert(node_inputs.begin(), onnx_node.input().begin(), onnx_node.input().end()); | |||
| if (SetOpInputIndex(node_inputs, dst_op, onnx_node, tensor_cache)) { | |||
| // MS_LOGE("SetOpInputIndex failed"); | |||
| MS_LOG(ERROR) << "SetOpInputIndex failed"; | |||
| return RET_ERROR; | |||
| } | |||
| // set op output index | |||
| std::vector<string> node_outputs; | |||
| (void)node_outputs.insert(node_outputs.begin(), onnx_node.output().begin(), onnx_node.output().end()); | |||
| if (SetOpOutputIndex(node_outputs, dst_op, tensor_cache) != RET_OK) { | |||
| // MS_LOGE("SetOpOutputIndex failed"); | |||
| MS_LOG(ERROR) << "SetOpOutputIndex failed"; | |||
| return RET_ERROR; | |||
| } | |||
| return RET_OK; | |||
| @@ -286,7 +288,7 @@ void OnnxModelParser::SetOpQuantParams(const onnx::GraphProto &onnx_graph, const | |||
| for (const auto &node : quant_node) { | |||
| std::unique_ptr<schema::QuantParamT> quant_param(new (std::nothrow) schema::QuantParamT()); | |||
| if (quant_param == nullptr) { | |||
| // MS_LOGE("new QuantParamT failed, node: %s", dst_op->name.c_str()); | |||
| MS_LOG(ERROR) << "new QuantParamT failed, node: " << dst_op->name; | |||
| return; | |||
| } | |||
| int argNum = 0; | |||
| @@ -322,7 +324,7 @@ STATUS OnnxModelParser::ParseOnnxNodeAttr(const onnx::GraphProto &onnx_graph, co | |||
| const string &onnx_op_type, schema::CNodeT *dst_op) { | |||
| auto node_parser = OnnxNodeParserRegistry::GetInstance()->GetNodeParser(onnx_op_type); | |||
| if (node_parser == nullptr) { | |||
| // MS_LOGE("not find %s, node parser is nullptr", onnx_op_type.c_str()); | |||
| MS_LOG(EXCEPTION) << "not find " << onnx_op_type << ", node parser is nullptr"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| return node_parser->Parse(onnx_graph, onnx_node, dst_op); | |||
| @@ -332,26 +334,32 @@ STATUS OnnxModelParser::SetOpInputIndex(const std::vector<string> &node_inputs, | |||
| const onnx::NodeProto &onnx_node, TensorCache *tensor_cache) { | |||
| schema::Format format = schema::Format_MAX; | |||
| for (const auto &onnx_node_attr : onnx_node.attribute()) { | |||
| if (onnx_node_attr.name() == "order") { | |||
| if (onnx_node_attr.name() == "order") { // do we need this code? onnx doc don't have order attr | |||
| MS_LOG(EXCEPTION) << "find order attr"; | |||
| if (onnx_node_attr.s() == "NHWC") { | |||
| format = schema::Format_NHWC; | |||
| } else { | |||
| // MS_LOGE("Unsupported format: %s", onnx_node_attr.s().c_str()); | |||
| MS_LOG(ERROR) << "Unsupported format: " << onnx_node_attr.s(); | |||
| return RET_ERROR; | |||
| } | |||
| } | |||
| } | |||
| for (const auto &onnx_node_input : node_inputs) { | |||
| auto index = tensor_cache->FindTensor(onnx_node_input); | |||
| if (index < 0) { | |||
| // MS_LOG(ERROR) << onnx_node.name() << " input " << onnx_node_input << " index in tensor_cache " << index; | |||
| if (index < 0) { // TODO(wangzhe) can this be ignored? because it's no use | |||
| /* | |||
| std::unique_ptr<schema::TensorT> tensor(new schema::TensorT); | |||
| index = tensor_cache->AddTensor(onnx_node_input, tensor.release(), OP_OUTPUT); | |||
| */ | |||
| MS_LOG(EXCEPTION) << "input " << onnx_node_input << " of node " << onnx_node.name() << " can't be found"; | |||
| // MS_LOG(INFO) << "new index: " << index; | |||
| } | |||
| if (format != schema::Format_MAX) { | |||
| if (format != schema::Format_MAX) { // TODO(wangzhe) also this | |||
| auto inTensor = tensor_cache->GetCachedTensor().at(index); | |||
| inTensor->format = format; | |||
| } | |||
| // MS_LOGD("node: %s, input index: %d", onnx_node_input.c_str(), index); | |||
| MS_LOG(DEBUG) << "node: " << onnx_node_input << ", input index: " << index; | |||
| dst_op->inputIndex.emplace_back(index); | |||
| } | |||
| return RET_OK; | |||
| @@ -362,23 +370,30 @@ STATUS OnnxModelParser::SetOpOutputIndex(const std::vector<string> &node_outputs | |||
| for (const auto &onnx_node_output : node_outputs) { | |||
| auto index = tensor_cache->FindTensor(onnx_node_output); | |||
| if (index < 0) { | |||
| MS_LOG(INFO) << "output of node " << dst_op->name << " not in tensor_cache, creating"; | |||
| MS_LOG(INFO) << "total " << node_outputs.size() << " outputs"; | |||
| std::unique_ptr<schema::TensorT> tensor(new schema::TensorT); | |||
| // GetDataTypeFromOnnx(static_cast<onnx::TensorProto_DataType>(proto.type().tensor_type().elem_type())); | |||
| // tensor->dataType = ; | |||
| // tensor->dims = tflite_tensor->shape; | |||
| tensor->nodeType = schema::NodeType_Parameter; | |||
| index = tensor_cache->AddTensor(onnx_node_output, tensor.release(), OP_OUTPUT); | |||
| } | |||
| // MS_LOGD("node: %s, input index: %d", onnx_node_output.c_str(), index); | |||
| MS_LOG(DEBUG) << "node: " << onnx_node_output << ", input index: " << index; | |||
| dst_op->outputIndex.emplace_back(index); | |||
| } | |||
| return RET_OK; | |||
| } | |||
| STATUS OnnxModelParser::CopyOnnxTensorData(const onnx::TensorProto &onnx_const_value, | |||
| schema::TensorT *tensor) { | |||
| STATUS OnnxModelParser::CopyOnnxTensorData(const onnx::TensorProto &onnx_const_value, schema::TensorT *tensor) { | |||
| size_t data_count = 1; | |||
| std::for_each(tensor->dims.begin(), tensor->dims.end(), [&data_count](int dim) { data_count *= dim; }); | |||
| size_t data_size = 0; | |||
| const void *tensor_data = nullptr; | |||
| switch (tensor->dataType) { | |||
| case kNumberTypeFloat: | |||
| case kNumberTypeFloat32: | |||
| data_size = data_count * sizeof(float); | |||
| if (onnx_const_value.float_data_size() == 0) { | |||
| tensor_data = onnx_const_value.raw_data().data(); | |||
| @@ -408,12 +423,12 @@ STATUS OnnxModelParser::CopyOnnxTensorData(const onnx::TensorProto &onnx_const_v | |||
| tensor_data = onnx_const_value.raw_data().data(); | |||
| break; | |||
| default: | |||
| // MS_LOGE("unsupported data type %d", tensor->dataType); | |||
| MS_LOG(ERROR) << "unsupported data type " << tensor->dataType; | |||
| return RET_ERROR; | |||
| } | |||
| tensor->data.resize(data_size); | |||
| if (memcpy_s(static_cast<void *>(tensor->data.data()), data_size, tensor_data, data_size) != 0) { | |||
| // MS_LOGE("memcpy_s failed") | |||
| MS_LOG(ERROR) << "memcpy_s failed"; | |||
| return RET_ERROR; | |||
| } | |||
| return RET_OK; | |||
| @@ -441,36 +456,37 @@ void OnnxModelParser::FindGraphInputAndConst(const onnx::GraphProto &onnx_graph) | |||
| } | |||
| } | |||
| MetaGraphT *OnnxModelParser::Parse(const std::string &modelFile, const std::string &weightFile) { | |||
| MetaGraphT *OnnxModelParser::Parse(const std::string &modelFile, const std::string &weightFile, | |||
| const QuantType &quantType) { | |||
| if (ValidateFileStr(modelFile, ".onnx") != RET_OK) { | |||
| // MS_LOGE("Input illegal: modelFile must be *.onnx"); | |||
| MS_LOG(ERROR) << "Input illegal: modelFile must be *.onnx"; | |||
| return nullptr; | |||
| } | |||
| std::unique_ptr<schema::MetaGraphT> dst_graph(new schema::MetaGraphT()); | |||
| onnx::ModelProto onnx_model; | |||
| if (ReadOnnxModelFromBinary(modelFile, &onnx_model) != RET_OK) { | |||
| // MS_LOGE("read onnx model fail"); | |||
| MS_LOG(ERROR) << "read onnx model fail"; | |||
| return nullptr; | |||
| } | |||
| const onnx::GraphProto &onnx_graph = onnx_model.graph(); | |||
| // MS_LOGI("model producer name: %s, graph name: %s", onnx_model.producer_name().c_str(), onnx_graph.name().c_str()); | |||
| MS_LOG(INFO) << "model producer name: " << onnx_model.producer_name() << ", graph name: " << onnx_graph.name(); | |||
| TensorCache tensor_cache; | |||
| dst_graph->name = onnx_graph.name(); | |||
| // dst_graph->name = onnx_graph.name(); // this is not used | |||
| // find out input names and const names | |||
| FindGraphInputAndConst(onnx_graph); | |||
| // set const tensor | |||
| if (SetGraphConstTensor(onnx_graph, &tensor_cache)) { | |||
| // MS_LOGE("SetGraphConstTensor failed"); | |||
| MS_LOG(ERROR) << "SetGraphConstTensor failed"; | |||
| return nullptr; | |||
| } | |||
| // init onnx model graph input tensor | |||
| if (SetGraphInputTensor(onnx_graph, dst_graph.get(), &tensor_cache)) { | |||
| // MS_LOGE("SetGraphInputTensor failed"); | |||
| MS_LOG(ERROR) << "SetGraphInputTensor failed"; | |||
| return nullptr; | |||
| } | |||
| // init onnx model graph output tensor | |||
| if (SetGraphOutputTensor(onnx_graph, dst_graph.get(), &tensor_cache)) { | |||
| // MS_LOGE("SetGraphOutputTensor failed"); | |||
| MS_LOG(ERROR) << "SetGraphOutputTensor failed"; | |||
| return nullptr; | |||
| } | |||
| // init op node input/output tensor, and dst_op attr | |||
| @@ -481,7 +497,7 @@ MetaGraphT *OnnxModelParser::Parse(const std::string &modelFile, const std::stri | |||
| } else if (onnx_node.op_type() == "Int8GivenIntTensorFill" || onnx_node.op_type() == "Int8GivenTensorFill") { | |||
| auto status = ParseOnnxGivenFillNode(onnx_node, &tensor_cache); | |||
| if (status != RET_OK) { | |||
| // MS_LOGE("ParseOnnxGivenFillNode failed: %d", status); | |||
| MS_LOG(ERROR) << "ParseOnnxGivenFillNode failed: " << status; | |||
| return nullptr; | |||
| } | |||
| continue; | |||
| @@ -489,18 +505,16 @@ MetaGraphT *OnnxModelParser::Parse(const std::string &modelFile, const std::stri | |||
| std::unique_ptr<schema::CNodeT> dst_op(new schema::CNodeT); | |||
| std::unique_ptr<schema::TensorT> dst_tensor(new schema::TensorT); | |||
| if (ParseOnnxNodeToDstOp(onnx_graph, onnx_node, dst_op.get(), dst_tensor.get(), &tensor_cache)) { | |||
| // MS_LOGE("parse node %s failed", onnx_node.op_type().c_str()) | |||
| auto status = ParseOnnxNodeToDstOp(onnx_graph, onnx_node, dst_op.get(), dst_tensor.get(), &tensor_cache); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "parse node " << onnx_node.op_type() << " failed"; | |||
| return nullptr; | |||
| } | |||
| dst_graph->nodes.emplace_back(std::move(dst_op)); | |||
| } | |||
| SetAllTensors(tensor_cache, dst_graph.get()); | |||
| dst_graph->mempoolSize = 0; | |||
| dst_graph->name = GetModelName(modelFile); | |||
| return dst_graph.release(); | |||
| // return Fb2Anf(dst_graph.release()); | |||
| } | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -27,9 +27,10 @@ | |||
| #include <memory> | |||
| #include <set> | |||
| #include "securec/include/securec.h" | |||
| #include "mindspore/lite/tools/converter/model_parser.h" | |||
| #include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h" | |||
| #include "tools/converter/model_parser.h" | |||
| #include "tools/converter/parser/onnx/onnx_node_parser_registry.h" | |||
| #include "tools/common/tensor_util.h" | |||
| #include "tools/converter/parser/onnx/onnx.pb.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| @@ -41,30 +42,24 @@ class OnnxModelParser : public ModelParser { | |||
| const QuantType &quantType = QuantType_QUANT_NONE) override; | |||
| private: | |||
| TypeId GetDateTypeFromOnnx(onnx::TensorProto_DataType onnx_type); | |||
| TypeId GetDataTypeFromOnnx(onnx::TensorProto_DataType onnx_type); | |||
| std::vector<int32_t> GetDimsFromOnnxValue(const onnx::ValueInfoProto &onnx_value); | |||
| STATUS ReadOnnxModelFromBinary(const std::string &modelFile, google::protobuf::Message *model_proto); | |||
| STATUS SetGraphConstTensor(const onnx::GraphProto &onnx_graph, TensorCache *tensor_cache); | |||
| STATUS SetGraphInputTensor(const onnx::GraphProto &onnx_graph, schema::MetaGraphT *graph, TensorCache *tensor_cache); | |||
| STATUS SetGraphOutputTensor(const onnx::GraphProto &onnx_graph, schema::MetaGraphT *graph, TensorCache *tensor_cache); | |||
| STATUS AddTensorCache(const onnx::ValueInfoProto &proto, schema::TensorT *tensor); | |||
| STATUS ParseOnnxNodeToDstOp(const onnx::GraphProto &onnx_graph, | |||
| const onnx::NodeProto &onnx_node, | |||
| schema::CNodeT *dst_op, | |||
| schema::TensorT *dst_tensor, TensorCache *tensor_cache); | |||
| void ParseOnnxGemmNode(const onnx::GraphProto &onnx_graph, | |||
| const onnx::NodeProto &onnx_node, | |||
| schema::MetaGraphT *graph, | |||
| TensorCache *tensor_cache); | |||
| STATUS ParseOnnxNodeToDstOp(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, | |||
| schema::CNodeT *dst_op, schema::TensorT *dst_tensor, TensorCache *tensor_cache); | |||
| void ParseOnnxGemmNode(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, | |||
| schema::MetaGraphT *graph, TensorCache *tensor_cache); | |||
| STATUS ParseOnnxGivenFillNode(const onnx::NodeProto &onnx_node, TensorCache *tensor_cache); | |||
| STATUS ParseOnnxNodeAttr(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, | |||
| const string &onnx_op_type, schema::CNodeT *dst_op); | |||
| void SetOpQuantParams(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *dst_op, | |||
| schema::TensorT *dst_tensor, TensorCache *tensor_cache); | |||
| STATUS SetOpInputIndex(const std::vector<string> &node_inputs, | |||
| schema::CNodeT *dst_op, | |||
| const onnx::NodeProto &onnx_node, | |||
| TensorCache *tensor_cache); | |||
| STATUS SetOpInputIndex(const std::vector<string> &node_inputs, schema::CNodeT *dst_op, | |||
| const onnx::NodeProto &onnx_node, TensorCache *tensor_cache); | |||
| STATUS SetOpOutputIndex(const std::vector<string> &node_outputs, schema::CNodeT *dst_op, TensorCache *tensor_cache); | |||
| STATUS CopyOnnxTensorData(const onnx::TensorProto &onnx_init_value, schema::TensorT *tensor); | |||
| STATUS SetAllTensors(const TensorCache &tensor_cache, schema::MetaGraphT *graphDef); | |||
| @@ -78,4 +73,3 @@ class OnnxModelParser : public ModelParser { | |||
| } // namespace mindspore | |||
| #endif // MS_ONNX_MODEL_PARSER_H | |||
| @@ -14,7 +14,7 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h" | |||
| #include "tools/converter/parser/onnx/onnx_node_parser.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| @@ -19,9 +19,9 @@ | |||
| #include <string> | |||
| #include "google/protobuf/message.h" | |||
| #include "mindspore/lite/tools/converter/proto/onnx.pb.h" | |||
| #include "tools/converter/parser/onnx/onnx.pb.h" | |||
| #include "tools/common/node_util.h" | |||
| #include "mindspore/lite/schema/inner/model_generated.h" | |||
| #include "schema/inner/model_generated.h" | |||
| // using namespace std; | |||
| @@ -14,7 +14,7 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h" | |||
| #include "tools/converter/parser/onnx/onnx_node_parser_registry.h" | |||
| #include <string> | |||
| namespace mindspore { | |||
| @@ -33,13 +33,14 @@ OnnxNodeParser *OnnxNodeParserRegistry::GetNodeParser(const std::string &name) { | |||
| if (it != parsers.end()) { | |||
| return it->second; | |||
| } | |||
| /* should not support vague name, otherwise may get wrong parser. ex. PRelu and Relu | |||
| for (auto const &i : parsers) { | |||
| if (name.find(i.first) != std::string::npos) { | |||
| return i.second; | |||
| } | |||
| } | |||
| */ | |||
| return nullptr; | |||
| } | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -19,8 +19,7 @@ | |||
| #include <string> | |||
| #include <unordered_map> | |||
| #include "mindspore/lite/tools/converter/proto/onnx.pb.h" | |||
| #include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h" | |||
| #include "tools/converter/parser/onnx/onnx_node_parser.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| @@ -15,12 +15,13 @@ | |||
| */ | |||
| #include <memory> | |||
| #include "mindspore/lite/tools/converter/parser/onnx/onnx_pad_parser.h" | |||
| #include "tools/converter/parser/onnx/onnx_pad_parser.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS OnnxPadParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { | |||
| unique_ptr<schema::PadT> attr(new schema::PadT()); | |||
| MS_LOG(DEBUG) << "onnx PadParser"; | |||
| std::unique_ptr<schema::PadT> attr(new schema::PadT()); | |||
| for (const auto &onnx_node_attr : onnx_node.attribute()) { | |||
| const auto &attribute_name = onnx_node_attr.name(); | |||
| if (attribute_name == "pads") { | |||
| @@ -33,11 +34,11 @@ STATUS OnnxPadParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Node | |||
| } else if (attribute_name == "mode") { | |||
| const auto &mode = onnx_node_attr.s(); | |||
| if (mode == "constant") { | |||
| attr->paddingmode = schema::PaddingMode_CONSTANT; | |||
| attr->paddingMode = schema::PaddingMode_CONSTANT; | |||
| } else if (mode == "reflect") { | |||
| attr->paddingmode = schema::PaddingMode_REFLECT; | |||
| attr->paddingMode = schema::PaddingMode_REFLECT; | |||
| } else if (mode == "edge") { | |||
| attr->paddingmode = schema::PaddingMode_SYMMETRIC; | |||
| attr->paddingMode = schema::PaddingMode_SYMMETRIC; | |||
| } | |||
| } | |||
| } | |||
| @@ -17,8 +17,8 @@ | |||
| #ifndef MS_ONNX_LRN_PARSER_H | |||
| #define MS_ONNX_LRN_PARSER_H | |||
| #include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h" | |||
| #include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h" | |||
| #include "tools/converter/parser/onnx/onnx_node_parser.h" | |||
| #include "tools/converter/parser/onnx/onnx_node_parser_registry.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| @@ -15,12 +15,13 @@ | |||
| */ | |||
| #include <memory> | |||
| #include "mindspore/lite/tools/converter/parser/onnx/onnx_pool_parser.h" | |||
| #include "tools/converter/parser/onnx/onnx_pool_parser.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS OnnxPoolParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { | |||
| unique_ptr<schema::PoolingT> attr(new schema::PoolingT()); | |||
| MS_LOG(DEBUG) << "onnx PoolParser"; | |||
| std::unique_ptr<schema::PoolingT> attr(new schema::PoolingT()); | |||
| const auto &pool_type = onnx_node.op_type(); | |||
| if (pool_type == "MaxPool") { | |||
| @@ -41,6 +42,8 @@ STATUS OnnxPoolParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod | |||
| } | |||
| attr->roundMode = schema::RoundMode_FLOOR; | |||
| attr->strideW = 1; | |||
| attr->strideH = 1; | |||
| for (const auto &onnx_node_attr : onnx_node.attribute()) { | |||
| const auto &attribute_name = onnx_node_attr.name(); | |||
| if (attribute_name == "kernel_shape") { | |||
| @@ -17,8 +17,8 @@ | |||
| #ifndef MS_ONNX_POOL_PARSER_H | |||
| #define MS_ONNX_POOL_PARSER_H | |||
| #include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h" | |||
| #include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h" | |||
| #include "tools/converter/parser/onnx/onnx_node_parser.h" | |||
| #include "tools/converter/parser/onnx/onnx_node_parser_registry.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| @@ -15,14 +15,15 @@ | |||
| */ | |||
| #include <memory> | |||
| #include "mindspore/lite/tools/converter/parser/onnx/onnx_reduce_parser.h" | |||
| #include "tools/converter/parser/onnx/onnx_reduce_parser.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS OnnxReduceParser::Parse(const onnx::GraphProto &onnx_graph, | |||
| const onnx::NodeProto &onnx_node, | |||
| schema::CNodeT *op) { | |||
| unique_ptr<schema::ReduceT> attr(new schema::ReduceT()); | |||
| MS_LOG(DEBUG) << "onnx ReduceParser"; | |||
| std::unique_ptr<schema::ReduceT> attr(new schema::ReduceT()); | |||
| for (const auto &onnx_node_attr : onnx_node.attribute()) { | |||
| const auto &attribute_name = onnx_node_attr.name(); | |||
| if (attribute_name == "axes") { | |||
| @@ -17,8 +17,8 @@ | |||
| #ifndef MS_ONNX_REDUCE_PARSER_H | |||
| #define MS_ONNX_REDUCE_PARSER_H | |||
| #include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h" | |||
| #include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h" | |||
| #include "tools/converter/parser/onnx/onnx_node_parser.h" | |||
| #include "tools/converter/parser/onnx/onnx_node_parser_registry.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| @@ -16,12 +16,13 @@ | |||
| #include <vector> | |||
| #include <memory> | |||
| #include "mindspore/lite/tools/converter/parser/onnx/onnx_relu_parser.h" | |||
| #include "tools/converter/parser/onnx/onnx_relu_parser.h" | |||
| #include "securec/include/securec.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS OnnxReluParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { | |||
| unique_ptr<schema::ActivationT> attr(new schema::ActivationT()); | |||
| MS_LOG(DEBUG) << "onnx ReluParser"; | |||
| std::unique_ptr<schema::ActivationT> attr(new schema::ActivationT()); | |||
| const auto &relu_type = onnx_node.op_type(); | |||
| if (relu_type == "Relu") { | |||
| attr->type = schema::ActivationType_RELU; | |||
| @@ -30,44 +31,52 @@ STATUS OnnxReluParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod | |||
| } | |||
| if (op != nullptr) { | |||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| op->primitive->value.type = schema::PrimitiveType_Activation; | |||
| op->primitive->value.value = attr.release(); | |||
| } | |||
| return RET_OK; | |||
| } | |||
| STATUS OnnxPReluParser::Parse(const onnx::GraphProto &onnx_graph, | |||
| const onnx::NodeProto &onnx_node, | |||
| STATUS OnnxPReluParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, | |||
| schema::CNodeT *op) { | |||
| MS_LOG(DEBUG) << "onnx PReluParser"; | |||
| if (onnx_node.input_size() != 2) { | |||
| // MS_LOGE("input num is not 2") | |||
| MS_LOG(ERROR) << "input num is not 2"; | |||
| return RET_PARAM_INVALID; | |||
| } | |||
| unique_ptr<schema::PreluT> attr(new schema::PreluT()); | |||
| std::unique_ptr<schema::CaffePReLUT> attr(new schema::CaffePReLUT()); | |||
| std::vector<onnx::TensorProto> params; | |||
| for (int i = 0; i < onnx_node.input_size(); ++i) { | |||
| const auto &input_name = onnx_node.input(i); | |||
| for ( const auto &it : onnx_graph.initializer() ) { | |||
| if (it.name() == "input_name") { | |||
| params.push_back(it); | |||
| break; | |||
| } | |||
| const auto &input_name = onnx_node.input(1); | |||
| for (const auto &it : onnx_graph.initializer()) { | |||
| if (it.name() == input_name) { | |||
| params.push_back(it); | |||
| break; | |||
| } | |||
| } | |||
| const onnx::TensorProto *slope = ¶ms[0]; | |||
| if (slope == nullptr) { | |||
| // MS_LOGE("input error") | |||
| MS_LOG(ERROR) << "input error"; | |||
| return RET_PARAM_INVALID; | |||
| } | |||
| const auto slope_raw_data = reinterpret_cast<const float *>(slope->raw_data().data()); | |||
| const int64_t slope_size = slope->raw_data().size() / sizeof(float); | |||
| if (memcpy_s(attr->slope.data(), slope_size * sizeof(float), slope_raw_data, slope_size * sizeof(float)) != 0) { | |||
| // MS_LOGE("memcpy_s failed") | |||
| return RET_ERROR; | |||
| if (slope_size == 1) { | |||
| attr->slope.push_back(*slope_raw_data); | |||
| attr->channelShared = true; | |||
| } else { // TODO(wangzhe) we don't check input tensor's channel size, this may cause problem | |||
| attr->slope.resize(slope_size); | |||
| attr->channelShared = false; | |||
| if (memcpy_s(attr->slope.data(), slope_size * sizeof(float), slope_raw_data, slope_size * sizeof(float)) != 0) { | |||
| MS_LOG(ERROR) << "memcpy_s failed"; | |||
| return RET_ERROR; | |||
| } | |||
| } | |||
| if (op != nullptr) { | |||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| op->primitive->value.type = schema::PrimitiveType_Prelu; | |||
| op->primitive->value.type = schema::PrimitiveType_CaffePReLU; | |||
| op->primitive->value.value = attr.release(); | |||
| } | |||
| return RET_OK; | |||
| @@ -75,7 +84,6 @@ STATUS OnnxPReluParser::Parse(const onnx::GraphProto &onnx_graph, | |||
| OnnxNodeRegistrar g_onnxReluParser("Relu", new OnnxReluParser()); | |||
| OnnxNodeRegistrar g_onnxLeakyReluParser("LeakyRelu", new OnnxLeakeyReluParser()); | |||
| OnnxNodeRegistrar g_onnxPReluParser("Prelu", new OnnxPReluParser()); | |||
| OnnxNodeRegistrar g_onnxPReluParser("PRelu", new OnnxPReluParser()); | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -17,8 +17,8 @@ | |||
| #ifndef MS_ONNX_RELU_PARSER_H | |||
| #define MS_ONNX_RELU_PARSER_H | |||
| #include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h" | |||
| #include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h" | |||
| #include "tools/converter/parser/onnx/onnx_node_parser.h" | |||
| #include "tools/converter/parser/onnx/onnx_node_parser_registry.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| @@ -16,17 +16,17 @@ | |||
| #include <vector> | |||
| #include <memory> | |||
| #include "mindspore/lite/tools/converter/parser/onnx/onnx_reshape_parser.h" | |||
| #include "tools/converter/parser/onnx/onnx_reshape_parser.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS OnnxReshapeParser::Parse(const onnx::GraphProto &onnx_graph, | |||
| const onnx::NodeProto &onnx_node, | |||
| STATUS OnnxReshapeParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, | |||
| schema::CNodeT *op) { | |||
| unique_ptr<schema::ReshapeT> attr(new schema::ReshapeT()); | |||
| attr->format = schema::Format_NHWC; | |||
| MS_LOG(DEBUG) << "onnx ReshapeParser"; | |||
| std::unique_ptr<schema::ReshapeT> attr(new schema::ReshapeT()); | |||
| attr->format = schema::Format_NCHW; | |||
| std::vector<onnx::TensorProto> params; | |||
| // TODO(wangzhe) shape may also come from other op, there need refactor to introduce tensor_cache | |||
| for (int i = 0; i < onnx_node.input_size(); ++i) { | |||
| const auto &input_name = onnx_node.input(i); | |||
| for (const auto &it : onnx_graph.initializer()) { | |||
| @@ -37,16 +37,16 @@ STATUS OnnxReshapeParser::Parse(const onnx::GraphProto &onnx_graph, | |||
| } | |||
| } | |||
| if (params.empty()) { | |||
| return RET_OK; | |||
| } | |||
| if (params.size() != 1) { | |||
| // MS_LOGE("input num is ,not equal 1", params.size()) | |||
| return RET_PARAM_INVALID; | |||
| } | |||
| MS_LOG(DEBUG) << "shape from another op other than const initializer"; | |||
| } else { | |||
| if (params.size() != 1) { | |||
| MS_LOG(ERROR) << "shape param num is " << params.size() << ", not equal to 1"; | |||
| return RET_PARAM_INVALID; | |||
| } | |||
| auto pre_shape = params[0]; | |||
| for (int i = 0; i < pre_shape.dims_size(); ++i) { | |||
| attr->shape.emplace_back(params[0].dims(i)); | |||
| for (int i = 0; i < params[0].int64_data_size(); ++i) { | |||
| attr->shape.emplace_back(params[0].int64_data(i)); | |||
| } | |||
| } | |||
| if (op != nullptr) { | |||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| @@ -59,4 +59,3 @@ STATUS OnnxReshapeParser::Parse(const onnx::GraphProto &onnx_graph, | |||
| OnnxNodeRegistrar g_onnxReshapeParser("Reshape", new OnnxReshapeParser()); | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -17,8 +17,8 @@ | |||
| #ifndef MS_ONNX_RESHAPE_PARSER_H | |||
| #define MS_ONNX_RESHAPE_PARSER_H | |||
| #include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h" | |||
| #include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h" | |||
| #include "tools/converter/parser/onnx/onnx_node_parser.h" | |||
| #include "tools/converter/parser/onnx/onnx_node_parser_registry.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| @@ -15,17 +15,19 @@ | |||
| */ | |||
| #include <memory> | |||
| #include "mindspore/lite/tools/converter/parser/onnx/onnx_shape_parser.h" | |||
| #include "tools/converter/parser/onnx/onnx_shape_parser.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS OnnxShapeParser::Parse(const onnx::GraphProto &onnx_graph, | |||
| const onnx::NodeProto &onnx_node, | |||
| schema::CNodeT *op) { | |||
| MS_LOG(DEBUG) << "onnx ShapeParser"; | |||
| if (op != nullptr) { | |||
| std::unique_ptr<schema::ShapeT> attr(new schema::ShapeT()); | |||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| op->primitive->value.type = schema::PrimitiveType_Shape; | |||
| op->primitive->value.value = nullptr; | |||
| op->primitive->value.value = attr.release(); | |||
| } | |||
| return RET_OK; | |||
| } | |||
| @@ -17,8 +17,8 @@ | |||
| #ifndef MS_ONNX_SHAPE_PARSER_H | |||
| #define MS_ONNX_SHAPE_PARSER_H | |||
| #include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h" | |||
| #include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h" | |||
| #include "tools/converter/parser/onnx/onnx_node_parser.h" | |||
| #include "tools/converter/parser/onnx/onnx_node_parser_registry.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| @@ -15,14 +15,15 @@ | |||
| */ | |||
| #include <memory> | |||
| #include "mindspore/lite/tools/converter/parser/onnx/onnx_sigmoid_parser.h" | |||
| #include "tools/converter/parser/onnx/onnx_sigmoid_parser.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS OnnxSigmoidParser::Parse(const onnx::GraphProto &onnx_graph, | |||
| const onnx::NodeProto &onnx_node, | |||
| schema::CNodeT *op) { | |||
| unique_ptr<schema::ActivationT> attr(new schema::ActivationT()); | |||
| MS_LOG(DEBUG) << "onnx SigmoidParser"; | |||
| std::unique_ptr<schema::ActivationT> attr(new schema::ActivationT()); | |||
| attr->type = schema::ActivationType_SIGMOID; | |||
| if (op != nullptr) { | |||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| @@ -17,8 +17,8 @@ | |||
| #ifndef MS_ONNX_SIGMOID_PARSER_H | |||
| #define MS_ONNX_SIGMOID_PARSER_H | |||
| #include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h" | |||
| #include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h" | |||
| #include "tools/converter/parser/onnx/onnx_node_parser.h" | |||
| #include "tools/converter/parser/onnx/onnx_node_parser_registry.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| @@ -15,14 +15,15 @@ | |||
| */ | |||
| #include <memory> | |||
| #include "mindspore/lite/tools/converter/parser/onnx/onnx_slice_parser.h" | |||
| #include "tools/converter/parser/onnx/onnx_slice_parser.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS OnnxSliceParser::Parse(const onnx::GraphProto &onnx_graph, | |||
| const onnx::NodeProto &onnx_node, | |||
| schema::CNodeT *op) { | |||
| unique_ptr<schema::SliceT> attr(new schema::SliceT()); | |||
| MS_LOG(DEBUG) << "onnx SliceParser"; | |||
| std::unique_ptr<schema::SliceT> attr(new schema::SliceT()); | |||
| for (const auto &onnx_node_attr : onnx_node.attribute()) { | |||
| const auto& attribute_name = onnx_node_attr.name(); | |||
| if (attribute_name == "starts") { | |||
| @@ -17,8 +17,8 @@ | |||
| #ifndef MS_ONNX_SLICE_PARSER_H | |||
| #define MS_ONNX_SLICE_PARSER_H | |||
| #include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h" | |||
| #include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h" | |||
| #include "tools/converter/parser/onnx/onnx_node_parser.h" | |||
| #include "tools/converter/parser/onnx/onnx_node_parser_registry.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| @@ -15,14 +15,15 @@ | |||
| */ | |||
| #include <memory> | |||
| #include "mindspore/lite/tools/converter/parser/onnx/onnx_softmax_parser.h" | |||
| #include "tools/converter/parser/onnx/onnx_softmax_parser.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS OnnxSoftMaxParser::Parse(const onnx::GraphProto &onnx_graph, | |||
| const onnx::NodeProto &onnx_node, | |||
| schema::CNodeT *op) { | |||
| unique_ptr<schema::SoftMaxT> attr(new schema::SoftMaxT()); | |||
| MS_LOG(DEBUG) << "onnx SoftMaxParser"; | |||
| std::unique_ptr<schema::SoftMaxT> attr(new schema::SoftMaxT()); | |||
| for (const auto &onnx_node_attr : onnx_node.attribute()) { | |||
| const auto& attribute_name = onnx_node_attr.name(); | |||
| if (attribute_name == "axis") { | |||
| @@ -17,8 +17,8 @@ | |||
| #ifndef MS_ONNX_SOFTMAX_PARSER_H | |||
| #define MS_ONNX_SOFTMAX_PARSER_H | |||
| #include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h" | |||
| #include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h" | |||
| #include "tools/converter/parser/onnx/onnx_node_parser.h" | |||
| #include "tools/converter/parser/onnx/onnx_node_parser_registry.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| @@ -15,14 +15,14 @@ | |||
| */ | |||
| #include <memory> | |||
| #include "mindspore/lite/tools/converter/parser/onnx/onnx_space_to_depth_parser.h" | |||
| #include "tools/converter/parser/onnx/onnx_space_to_depth_parser.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS OnnxSPaceToDepthParser::Parse(const onnx::GraphProto &onnx_graph, | |||
| const onnx::NodeProto &onnx_node, | |||
| STATUS OnnxSpaceToDepthParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, | |||
| schema::CNodeT *op) { | |||
| unique_ptr<schema::SpaceToDepthT> attr(new schema::SpaceToDepthT()); | |||
| MS_LOG(DEBUG) << "onnx SpaceToDepthParser"; | |||
| std::unique_ptr<schema::SpaceToDepthT> attr(new schema::SpaceToDepthT()); | |||
| for (const auto &onnx_node_attr : onnx_node.attribute()) { | |||
| const auto &attribute_name = onnx_node_attr.name(); | |||
| if (attribute_name == "blocksize") { | |||
| @@ -37,7 +37,6 @@ STATUS OnnxSPaceToDepthParser::Parse(const onnx::GraphProto &onnx_graph, | |||
| return RET_OK; | |||
| } | |||
| OnnxNodeRegistrar g_onnxSpaceToDepthParser("SpaceToDepth", new OnnxSPaceToDepthParser()); | |||
| OnnxNodeRegistrar g_onnxSpaceToDepthParser("SpaceToDepth", new OnnxSpaceToDepthParser()); | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -17,14 +17,14 @@ | |||
| #ifndef MS_ONNX_SPACE_TO_DEPTH_PARSER_H | |||
| #define MS_ONNX_SPACE_TO_DEPTH_PARSER_H | |||
| #include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h" | |||
| #include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h" | |||
| #include "tools/converter/parser/onnx/onnx_node_parser.h" | |||
| #include "tools/converter/parser/onnx/onnx_node_parser_registry.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| class OnnxSPaceToDepthParser : public OnnxNodeParser { | |||
| class OnnxSpaceToDepthParser : public OnnxNodeParser { | |||
| public: | |||
| OnnxSPaceToDepthParser() : OnnxNodeParser("SpaceToDepth") {} | |||
| OnnxSpaceToDepthParser() : OnnxNodeParser("SpaceToDepth") {} | |||
| STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; | |||
| }; | |||
| } // namespace lite | |||
| @@ -15,14 +15,15 @@ | |||
| */ | |||
| #include <memory> | |||
| #include "mindspore/lite/tools/converter/parser/onnx/onnx_squeeze_parser.h" | |||
| #include "tools/converter/parser/onnx/onnx_squeeze_parser.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS OnnxSqueezeParser::Parse(const onnx::GraphProto &onnx_graph, | |||
| const onnx::NodeProto &onnx_node, | |||
| schema::CNodeT *op) { | |||
| unique_ptr<schema::SqueezeT> attr(new schema::SqueezeT()); | |||
| MS_LOG(DEBUG) << "onnx SqueezeParser"; | |||
| std::unique_ptr<schema::SqueezeT> attr(new schema::SqueezeT()); | |||
| for (const auto &onnx_node_attr : onnx_node.attribute()) { | |||
| const auto &attribute_name = onnx_node_attr.name(); | |||
| if (attribute_name == "axes") { | |||
| @@ -17,8 +17,8 @@ | |||
| #ifndef MS_ONNX_SQUEEZE_PARSER_H | |||
| #define MS_ONNX_SQUEEZE_PARSER_H | |||
| #include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h" | |||
| #include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h" | |||
| #include "tools/converter/parser/onnx/onnx_node_parser.h" | |||
| #include "tools/converter/parser/onnx/onnx_node_parser_registry.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| @@ -15,15 +15,17 @@ | |||
| */ | |||
| #include <memory> | |||
| #include "mindspore/lite/tools/converter/parser/onnx/onnx_tile_parser.h" | |||
| #include "tools/converter/parser/onnx/onnx_tile_parser.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS OnnxTileParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { | |||
| MS_LOG(DEBUG) << "onnx TileParser"; | |||
| if (op != nullptr) { | |||
| std::unique_ptr<schema::TileT> attr(new schema::TileT()); | |||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| op->primitive->value.type = schema::PrimitiveType_Tile; | |||
| op->primitive->value.value = nullptr; | |||
| op->primitive->value.value = attr.release(); | |||
| } | |||
| return RET_OK; | |||
| } | |||
| @@ -31,4 +33,3 @@ STATUS OnnxTileParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod | |||
| OnnxNodeRegistrar g_onnxTileParser("Tile", new OnnxTileParser()); | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -17,8 +17,8 @@ | |||
| #ifndef MS_ONNX_TILE_PARSER_H | |||
| #define MS_ONNX_TILE_PARSER_H | |||
| #include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h" | |||
| #include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h" | |||
| #include "tools/converter/parser/onnx/onnx_node_parser.h" | |||
| #include "tools/converter/parser/onnx/onnx_node_parser_registry.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| @@ -15,14 +15,15 @@ | |||
| */ | |||
| #include <memory> | |||
| #include "mindspore/lite/tools/converter/parser/onnx/onnx_transpose_parser.h" | |||
| #include "tools/converter/parser/onnx/onnx_transpose_parser.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS OnnxTransposeParser::Parse(const onnx::GraphProto &onnx_graph, | |||
| const onnx::NodeProto &onnx_node, | |||
| schema::CNodeT *op) { | |||
| unique_ptr<schema::TransposeT> attr(new schema::TransposeT()); | |||
| MS_LOG(DEBUG) << "onnx TransposeParser"; | |||
| std::unique_ptr<schema::TransposeT> attr(new schema::TransposeT()); | |||
| attr->conjugate = false; | |||
| for (const auto &onnx_node_attr : onnx_node.attribute()) { | |||
| const auto &attribute_name = onnx_node_attr.name(); | |||
| @@ -17,8 +17,8 @@ | |||
| #ifndef MS_ONNX_TRANSPOSE_PARSER_H | |||
| #define MS_ONNX_TRANSPOSE_PARSER_H | |||
| #include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h" | |||
| #include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h" | |||
| #include "tools/converter/parser/onnx/onnx_node_parser.h" | |||
| #include "tools/converter/parser/onnx/onnx_node_parser_registry.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| @@ -15,14 +15,15 @@ | |||
| */ | |||
| #include <memory> | |||
| #include "mindspore/lite/tools/converter/parser/onnx/onnx_unsample_parser.h" | |||
| #include "tools/converter/parser/onnx/onnx_unsample_parser.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS OnnxUpsampleParser::Parse(const onnx::GraphProto &onnx_graph, | |||
| const onnx::NodeProto &onnx_node, | |||
| schema::CNodeT *op) { | |||
| unique_ptr<schema::UpsampleT> attr(new schema::UpsampleT()); | |||
| MS_LOG(DEBUG) << "onnx UpsampleParser"; | |||
| std::unique_ptr<schema::UpsampleT> attr(new schema::UpsampleT()); | |||
| for (const auto &onnx_node_attr : onnx_node.attribute()) { | |||
| const auto &attribute_name = onnx_node_attr.name(); | |||
| if (attribute_name == "mode") { | |||
| @@ -17,8 +17,8 @@ | |||
| #ifndef MS_ONNX_UPSAMPLE_PARSER_H | |||
| #define MS_ONNX_UPSAMPLE_PARSER_H | |||
| #include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h" | |||
| #include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h" | |||
| #include "tools/converter/parser/onnx/onnx_node_parser.h" | |||
| #include "tools/converter/parser/onnx/onnx_node_parser_registry.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| @@ -15,14 +15,15 @@ | |||
| */ | |||
| #include <memory> | |||
| #include "mindspore/lite/tools/converter/parser/onnx/onnx_unsqueeze_parser.h" | |||
| #include "tools/converter/parser/onnx/onnx_unsqueeze_parser.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS OnnxUnSqueezeParser::Parse(const onnx::GraphProto &onnx_graph, | |||
| const onnx::NodeProto &onnx_node, | |||
| schema::CNodeT *op) { | |||
| unique_ptr<schema::UnsqueezeT> attr(new schema::UnsqueezeT()); | |||
| MS_LOG(DEBUG) << "onnx UnSqueezeParser"; | |||
| std::unique_ptr<schema::UnsqueezeT> attr(new schema::UnsqueezeT()); | |||
| for (const auto &onnx_node_attr : onnx_node.attribute()) { | |||
| const auto &attribute_name = onnx_node_attr.name(); | |||
| if (attribute_name == "axes") { | |||
| @@ -17,8 +17,8 @@ | |||
| #ifndef MS_ONNX_UNSQUEEZE_PARSER_H | |||
| #define MS_ONNX_UNSQUEEZE_PARSER_H | |||
| #include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h" | |||
| #include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h" | |||
| #include "tools/converter/parser/onnx/onnx_node_parser.h" | |||
| #include "tools/converter/parser/onnx/onnx_node_parser_registry.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| @@ -15,13 +15,14 @@ | |||
| */ | |||
| #include <memory> | |||
| #include "mindspore/lite/tools/converter/parser/onnx/onnx_unuseful_node_parser.h" | |||
| #include "tools/converter/parser/onnx/onnx_unuseful_node_parser.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS OnnxUnusefulNodeParser::Parse(const onnx::GraphProto &onnx_graph, | |||
| const onnx::NodeProto &onnx_node, | |||
| schema::CNodeT *op) { | |||
| MS_LOG(DEBUG) << "onnx UnusefulNodeParser"; | |||
| if (op != nullptr) { | |||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (onnx_node.op_type() == "Int8Quantize") { | |||
| @@ -17,8 +17,8 @@ | |||
| #ifndef MS_ONNX_UNUSEFUL_PARSER_H | |||
| #define MS_ONNX_UNUSEFUL_PARSER_H | |||
| #include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h" | |||
| #include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h" | |||
| #include "tools/converter/parser/onnx/onnx_node_parser.h" | |||
| #include "tools/converter/parser/onnx/onnx_node_parser_registry.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||