diff --git a/mindspore/lite/schema/model.fbs b/mindspore/lite/schema/model.fbs index 4b9245a3ef..d3dc1bb4dd 100644 --- a/mindspore/lite/schema/model.fbs +++ b/mindspore/lite/schema/model.fbs @@ -219,6 +219,10 @@ union PrimitiveType { Sgd, Adam, GroupConv2DGradInput, + Loop, + NonMaxSuppression, + InstanceNorm, + Identity, } enum QuantType: int { @@ -250,6 +254,7 @@ table MetaGraph { mempoolSize: uint; nodes: [CNode]; allTensors: [Tensor]; // weight + input + output + subGraph : [MetaGraph]; } root_type MetaGraph; diff --git a/mindspore/lite/schema/ops.fbs b/mindspore/lite/schema/ops.fbs index 55f7876374..ec50dcfbee 100644 --- a/mindspore/lite/schema/ops.fbs +++ b/mindspore/lite/schema/ops.fbs @@ -18,8 +18,28 @@ namespace mindspore.schema; enum ResizeMethod: byte { UNKNOW = -1, - BILINEAR = 0, - NEAREST_NEIGHBOR = 1 + LINEAR = 0, + NEAREST = 1, + CUBIC = 2 +} + +enum CoordinateTransformMode: byte { + COMMON = 0, + HALF_PIXEL = 1, + PYTORCH_HALF_PIXEL = 2, + TF_HALF_PIXEL = 3, + TF_CROP_AND_RESIZE = 4, + ALIGN_CORNERS = 5, + ASYMMETRIC = 6, + ALIGN_CORNERS_WITH_HALF_PIEXL = 7 +} + +enum NearestMode : byte { + NORMAL = 0, + ROUND_HALF_DOWN = 1, + ROUND_HALF_UP = 2, + FLOOR = 3, + CEIL = 4 } enum Format : int { @@ -376,8 +396,13 @@ table Resize { method: ResizeMethod; newHeight: long; newWidth: long; - alignCorners: bool = false; + alignCorners: bool = false; // DEPRECATED IN FUTURE: use 'coordinateTransformMode' instead. preserveAspectRatio: bool = false; + coordinateTransformMode : CoordinateTransformMode; + cubicCoeff : float; + excludeOutside : int; + extrapolationValue : float = 0; + nearestMode : NearestMode; } table DetectionPostProcess { @@ -1054,3 +1079,21 @@ table FftReal { table FftImag { } + +table NonMaxSuppression { + maxOutBoxPerClass : int = 0; + iouThreshold : float = 0; + scoreThreshold : float = 0; + centerPointBox : int = 0; +} + +table InstanceNorm { + epsilon : float = 0.00001; +} + +table Loop { + subGraphIndex : int; +} + +table Identity { +} diff --git a/mindspore/lite/src/ops/resize.cc b/mindspore/lite/src/ops/resize.cc index 7b23af5428..b509190923 100644 --- a/mindspore/lite/src/ops/resize.cc +++ b/mindspore/lite/src/ops/resize.cc @@ -51,9 +51,9 @@ int Resize::UnPackAttr(const Primitive &prim, const std::vector &inp if (this->primitive_->value.value == nullptr) { auto attr = new (std::nothrow) schema::ResizeT(); if (prim.instance_name() == "ResizeNearestNeighbor") { - attr->method = schema::ResizeMethod_NEAREST_NEIGHBOR; + attr->method = schema::ResizeMethod_NEAREST; } else if (prim.instance_name() == "ResizeBilinear") { - attr->method = schema::ResizeMethod_BILINEAR; + attr->method = schema::ResizeMethod_LINEAR; } else { MS_LOG(ERROR) << "wrong resize type"; return RET_ERROR; diff --git a/mindspore/lite/src/runtime/kernel/arm/base/resize_base.cc b/mindspore/lite/src/runtime/kernel/arm/base/resize_base.cc index 954d7897f2..c62028c0bf 100644 --- a/mindspore/lite/src/runtime/kernel/arm/base/resize_base.cc +++ b/mindspore/lite/src/runtime/kernel/arm/base/resize_base.cc @@ -41,8 +41,8 @@ int ResizeBaseCPUKernel::CheckParameters() { return RET_NULL_PTR; } method_ = parameter->method_; - if (method_ != static_cast(schema::ResizeMethod_BILINEAR) && - method_ != static_cast(schema::ResizeMethod_NEAREST_NEIGHBOR)) { + if (method_ != static_cast(schema::ResizeMethod_LINEAR) && + method_ != static_cast(schema::ResizeMethod_NEAREST)) { MS_LOG(ERROR) << "Resize method should be bilinear or nearest_neighbor, but got " << method_; return RET_INVALID_OP_ATTR; } diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/resize.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/resize.cc index 5b2eaa1341..cff4535ee4 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/resize.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/resize.cc @@ -14,11 +14,11 @@ * limitations under the License. */ -#include #include "src/runtime/kernel/arm/fp32/resize.h" -#include "schema/model_generated.h" -#include "nnacl/fp32/resize.h" +#include #include "include/errorcode.h" +#include "nnacl/fp32/resize.h" +#include "schema/model_generated.h" #include "src/runtime/runtime_api.h" using mindspore::kernel::KERNEL_ARCH::kCPU; @@ -41,7 +41,7 @@ int ResizeCPUKernel::Init() { int ResizeCPUKernel::ReSize() { int ret = RET_OK; - if (method_ == static_cast(schema::ResizeMethod_BILINEAR)) { + if (method_ == static_cast(schema::ResizeMethod_LINEAR)) { FreeTmpBuffer(); ret = MallocTmpBuffer(); if (ret != RET_OK) { @@ -162,7 +162,7 @@ int ResizeCPUKernel::RunImpl(int task_id) { int ret = 0; switch (method_) { - case static_cast(schema::ResizeMethod_BILINEAR): { + case static_cast(schema::ResizeMethod_LINEAR): { int n_h_begin, n_h_end; int n = out_tensors_.at(0)->shape()[0]; int h = new_height_; @@ -178,7 +178,7 @@ int ResizeCPUKernel::RunImpl(int task_id) { break; } - case static_cast(schema::ResizeMethod_NEAREST_NEIGHBOR): { + case static_cast(schema::ResizeMethod_NEAREST): { if (in_tensors_.size() == lite::kDoubleNum && !const_shape_) { auto out_shape = in_tensors_.at(1); auto data = reinterpret_cast(out_shape->MutableData()); diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/resize_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/resize_int8.cc index 87f3e91395..be9bdb3785 100644 --- a/mindspore/lite/src/runtime/kernel/arm/int8/resize_int8.cc +++ b/mindspore/lite/src/runtime/kernel/arm/int8/resize_int8.cc @@ -14,12 +14,12 @@ * limitations under the License. */ +#include "src/runtime/kernel/arm/int8/resize_int8.h" #include -#include "src/kernel_registry.h" +#include "include/errorcode.h" #include "nnacl/int8/resize.h" #include "schema/model_generated.h" -#include "include/errorcode.h" -#include "src/runtime/kernel/arm/int8/resize_int8.h" +#include "src/kernel_registry.h" #include "src/runtime/runtime_api.h" using mindspore::kernel::KERNEL_ARCH::kCPU; @@ -84,7 +84,7 @@ int ResizeInt8CPUKernel::RunImpl(int task_id) { int ret = 0; switch (method_) { - case static_cast(schema::ResizeMethod_BILINEAR): { + case static_cast(schema::ResizeMethod_LINEAR): { if (quant_in_->zp_ == 0) { ret = ResizeBilinearInt8(input_data, output_data, input_shape.data(), out_tensors_[0]->shape().data(), align_corners_, quant_in_, quant_out_, multiplier_, task_id, context_->thread_num_); @@ -95,7 +95,7 @@ int ResizeInt8CPUKernel::RunImpl(int task_id) { } break; } - case static_cast(schema::ResizeMethod_NEAREST_NEIGHBOR): { + case static_cast(schema::ResizeMethod_NEAREST): { bool same_zp = quant_in_->zp_ == quant_out_->zp_; bool same_scale = abs(quant_out_->scale_ - quant_in_->scale_) < 1e-6; if (same_zp && same_scale) { diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/resize.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/resize.cc index 8c8b298d7c..5495be6907 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/kernel/resize.cc +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/resize.cc @@ -14,12 +14,12 @@ * limitations under the License. */ +#include "src/runtime/kernel/opencl/kernel/resize.h" +#include #include #include -#include #include "include/errorcode.h" #include "src/kernel_registry.h" -#include "src/runtime/kernel/opencl/kernel/resize.h" #include "src/runtime/kernel/opencl/cl/resize.cl.inc" using mindspore::kernel::KERNEL_ARCH::kGPU; @@ -46,9 +46,9 @@ int ResizeOpenCLKernel::Init() { return RET_PARAM_INVALID; } std::string kernel_name = "resize"; - if (resize_param->method_ == schema::ResizeMethod_BILINEAR) { + if (resize_param->method_ == schema::ResizeMethod_LINEAR) { kernel_name += "_bilinear"; - } else if (resize_param->method_ == schema::ResizeMethod_NEAREST_NEIGHBOR) { + } else if (resize_param->method_ == schema::ResizeMethod_NEAREST) { kernel_name += "_nearest_neighbor"; } else { MS_LOG(ERROR) << "unsupported resize method:" << resize_param->method_; diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/resize_bilinear_fp32_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/resize_bilinear_fp32_tests.cc index 2bd9a93664..35b9effe61 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/resize_bilinear_fp32_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/resize_bilinear_fp32_tests.cc @@ -14,11 +14,11 @@ * limitations under the License. */ #include +#include "common/common_test.h" +#include "mindspore/lite/src/kernel_registry.h" #include "mindspore/lite/src/lite_kernel.h" #include "mindspore/lite/src/tensor.h" -#include "common/common_test.h" #include "nnacl/resize_parameter.h" -#include "mindspore/lite/src/kernel_registry.h" #include "schema/ops_generated.h" using mindspore::schema::Format_NHWC; @@ -62,7 +62,7 @@ void TestResizeBilinearFp32::Prepare(const std::vector &input_shape, const out_tensor_.SetData(output_data); ResizeParameter param_ = { - {}, static_cast(schema::ResizeMethod_BILINEAR), output_shape[1], output_shape[2], align_corners}; + {}, static_cast(schema::ResizeMethod_LINEAR), output_shape[1], output_shape[2], align_corners}; desc = {kernel::KERNEL_ARCH::kCPU, kNumberTypeFloat32, schema::PrimitiveType_Resize}; ctx_ = lite::InnerContext(); ctx_.thread_num_ = thread_num; diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/resize_nearest_neighbor_fp32_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/resize_nearest_neighbor_fp32_tests.cc index ed579145d5..30eec684a5 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/resize_nearest_neighbor_fp32_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/resize_nearest_neighbor_fp32_tests.cc @@ -16,7 +16,7 @@ #include #include "common/common_test.h" #include "nnacl/resize_parameter.h" -#include "mindspore/lite/src/kernel_registry.h" +#include "src/kernel_registry.h" namespace mindspore { @@ -57,7 +57,7 @@ void TestResizeNearestNeighborFp32::Prepare(const std::vector &input_shape, out_tensor_.SetData(output_data); ResizeParameter param_ = { - {}, static_cast(schema::ResizeMethod_NEAREST_NEIGHBOR), output_shape[1], output_shape[2], align_corners}; + {}, static_cast(schema::ResizeMethod_NEAREST), output_shape[1], output_shape[2], align_corners}; desc = {kernel::KERNEL_ARCH::kCPU, kNumberTypeFloat32, schema::PrimitiveType_Resize}; ctx_ = lite::InnerContext(); ctx_.thread_num_ = thread_num; diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/resize_bilinear_int8_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/resize_bilinear_int8_tests.cc index 076ceb1ad6..a2c7bd2deb 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/resize_bilinear_int8_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/resize_bilinear_int8_tests.cc @@ -19,7 +19,7 @@ #include "include/context.h" #include "src/tensor.h" #include "common/common_test.h" -#include "mindspore/lite/src/kernel_registry.h" +#include "src/kernel_registry.h" #include "nnacl/int8/resize.h" namespace mindspore { @@ -68,7 +68,7 @@ void TestResizeBilinearInt8::Prepare(const std::vector &in_shape, const std inputs.push_back(&in_tensor); outputs.push_back(&out_tensor); - param_.method_ = static_cast(schema::ResizeMethod_BILINEAR); + param_.method_ = static_cast(schema::ResizeMethod_LINEAR); param_.new_width_ = out_shape[2]; param_.new_height_ = out_shape[1]; param_.align_corners_ = align_corners; diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/resize_nearest_neighbor_int8_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/resize_nearest_neighbor_int8_tests.cc index 2b43d037be..a14d6dc3ab 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/resize_nearest_neighbor_int8_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/resize_nearest_neighbor_int8_tests.cc @@ -19,7 +19,7 @@ #include "include/context.h" #include "src/tensor.h" #include "common/common_test.h" -#include "mindspore/lite/src/kernel_registry.h" +#include "src/kernel_registry.h" #include "nnacl/int8/resize.h" namespace mindspore { @@ -63,7 +63,7 @@ void TestResizeNearestNeighborInt8::Prepare(const std::vector &in_shape, co inputs.push_back(&in_tensor); outputs.push_back(&out_tensor); - param_.method_ = static_cast(schema::ResizeMethod_NEAREST_NEIGHBOR); + param_.method_ = static_cast(schema::ResizeMethod_NEAREST); param_.new_width_ = out_shape[2]; param_.new_height_ = out_shape[1]; param_.align_corners_ = align_corners; diff --git a/mindspore/lite/test/ut/src/runtime/kernel/opencl/resize_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/opencl/resize_tests.cc index 9555e92417..ae6dd01ebe 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/opencl/resize_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/opencl/resize_tests.cc @@ -15,13 +15,13 @@ */ #include #include -#include "src/common/log_adapter.h" #include "common/common_test.h" -#include "mindspore/lite/src/common/file_utils.h" -#include "mindspore/lite/src/runtime/opencl/opencl_runtime.h" -#include "mindspore/lite/src/runtime/kernel/opencl/subgraph_opencl_kernel.h" -#include "mindspore/lite/src/runtime/kernel/opencl/kernel/resize.h" -#include "mindspore/lite/test/ut/src/runtime/kernel/opencl/utils_tests.h" +#include "src/common/file_utils.h" +#include "src/common/log_adapter.h" +#include "src/runtime/kernel/opencl/kernel/resize.h" +#include "src/runtime/kernel/opencl/subgraph_opencl_kernel.h" +#include "src/runtime/opencl/opencl_runtime.h" +#include "test/ut/src/runtime/kernel/opencl/utils_tests.h" namespace mindspore { class TestResizeOpenCL : public mindspore::CommonTest { @@ -119,7 +119,7 @@ TEST_F(TestResizeOpenCL, ResizeBilinearFp32) { std::vector input_data = {0.0f, 1.0f, 2.0f, 3.0f}; std::vector output_data = {0.0f, 0.5f, 1.0f, 1.0f, 1.0f, 1.5f, 2.0f, 2.0f, 2.0f, 2.5f, 3.0f, 3.0f, 2.0f, 2.5f, 3.0f, 3.0f}; - RunTestCaseResize(shape, input_data.data(), output_data.data(), false, schema::ResizeMethod_BILINEAR, align_corners); + RunTestCaseResize(shape, input_data.data(), output_data.data(), false, schema::ResizeMethod_LINEAR, align_corners); } TEST_F(TestResizeOpenCL, ResizeBilinearFp16) { @@ -134,7 +134,7 @@ TEST_F(TestResizeOpenCL, ResizeBilinearFp16) { std::vector input_data = {0.0f, 1.0f, 2.0f, 3.0f}; std::vector output_data = {0.0f, 0.5f, 1.0f, 1.0f, 1.0f, 1.5f, 2.0f, 2.0f, 2.0f, 2.5f, 3.0f, 3.0f, 2.0f, 2.5f, 3.0f, 3.0f}; - RunTestCaseResize(shape, input_data.data(), output_data.data(), true, schema::ResizeMethod_BILINEAR, align_corners); + RunTestCaseResize(shape, input_data.data(), output_data.data(), true, schema::ResizeMethod_LINEAR, align_corners); } TEST_F(TestResizeOpenCL, ResizeBilinearAlignFp32) { @@ -148,7 +148,7 @@ TEST_F(TestResizeOpenCL, ResizeBilinearAlignFp32) { std::vector shape = {n, h, w, oh, ow, c}; std::vector input_data = {0.0f, 1.0f, 2.0f, 3.0f}; std::vector output_data = {0.0f, 0.5f, 1.0f, 1.0f, 1.5f, 2.0f, 2.0f, 2.5f, 3.0f}; - RunTestCaseResize(shape, input_data.data(), output_data.data(), false, schema::ResizeMethod_BILINEAR, align_corners); + RunTestCaseResize(shape, input_data.data(), output_data.data(), false, schema::ResizeMethod_LINEAR, align_corners); } TEST_F(TestResizeOpenCL, ResizeNearestNeighborFp32) { @@ -163,8 +163,7 @@ TEST_F(TestResizeOpenCL, ResizeNearestNeighborFp32) { std::vector input_data = {0.0f, 1.0f, 2.0f, 3.0f}; std::vector output_data = {0.0f, 0.0f, 1.0f, 1.0f, 0.0f, 0.0f, 1.0f, 1.0f, 2.0f, 2.0f, 3.0f, 3.0f, 2.0f, 2.0f, 3.0f, 3.0f}; - RunTestCaseResize(shape, input_data.data(), output_data.data(), false, schema::ResizeMethod_NEAREST_NEIGHBOR, - align_corners); + RunTestCaseResize(shape, input_data.data(), output_data.data(), false, schema::ResizeMethod_NEAREST, align_corners); } TEST_F(TestResizeOpenCL, ResizeNearestNeighborFp16) { @@ -179,7 +178,6 @@ TEST_F(TestResizeOpenCL, ResizeNearestNeighborFp16) { std::vector input_data = {0.0f, 1.0f, 2.0f, 3.0f}; std::vector output_data = {0.0f, 0.0f, 1.0f, 1.0f, 0.0f, 0.0f, 1.0f, 1.0f, 2.0f, 2.0f, 3.0f, 3.0f, 2.0f, 2.0f, 3.0f, 3.0f}; - RunTestCaseResize(shape, input_data.data(), output_data.data(), true, schema::ResizeMethod_NEAREST_NEIGHBOR, - align_corners); + RunTestCaseResize(shape, input_data.data(), output_data.data(), true, schema::ResizeMethod_NEAREST, align_corners); } } // namespace mindspore diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_resize_parser_test.cc b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_resize_parser_test.cc index cfff0c88ae..960ea20913 100644 --- a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_resize_parser_test.cc +++ b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_resize_parser_test.cc @@ -40,7 +40,7 @@ TEST_F(TestTfliteParserResizeNN, AttrValue) { ASSERT_EQ(val->newWidth, 100); ASSERT_EQ(val->format, schema::Format_NHWC); ASSERT_EQ(val->preserveAspectRatio, false); - ASSERT_EQ(val->method, schema::ResizeMethod_NEAREST_NEIGHBOR); + ASSERT_EQ(val->method, schema::ResizeMethod_NEAREST); } class TestTfliteParserResizeBilinear : public TestTfliteParser { @@ -64,7 +64,7 @@ TEST_F(TestTfliteParserResizeBilinear, AttrValue) { ASSERT_EQ(val->newWidth, 4); ASSERT_EQ(val->format, schema::Format_NHWC); ASSERT_EQ(val->preserveAspectRatio, false); - ASSERT_EQ(val->method, schema::ResizeMethod_BILINEAR); + ASSERT_EQ(val->method, schema::ResizeMethod_LINEAR); } } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/caffe/caffe_interp_parser.cc b/mindspore/lite/tools/converter/parser/caffe/caffe_interp_parser.cc index 87bdcc1ba3..be11c2e2c7 100644 --- a/mindspore/lite/tools/converter/parser/caffe/caffe_interp_parser.cc +++ b/mindspore/lite/tools/converter/parser/caffe/caffe_interp_parser.cc @@ -57,7 +57,7 @@ STATUS CaffeInterpParser::Parse(const caffe::LayerParameter &proto, const caffe: attr->newWidth = width; } attr->alignCorners = true; - attr->method = schema::ResizeMethod_BILINEAR; + attr->method = schema::ResizeMethod_LINEAR; op->name = proto.name(); op->primitive->value.type = schema::PrimitiveType_Resize; diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_arithmetic_operation_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_arithmetic_operation_parser.cc index 6bd54d597d..12bdde3af9 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_arithmetic_operation_parser.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_arithmetic_operation_parser.cc @@ -582,6 +582,94 @@ STATUS OnnxSignParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod return RET_OK; } +STATUS OnnxAndParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { + MS_LOG(DEBUG) << "onnx AndParser"; + if (op == nullptr) { + MS_LOG(ERROR) << "op is null"; + return RET_NULL_PTR; + } + op->primitive = std::make_unique(); + if (op->primitive == nullptr) { + MS_LOG(ERROR) << "op->primitive is null"; + return RET_NULL_PTR; + } + + std::unique_ptr attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return RET_NULL_PTR; + } + op->primitive->value.type = schema::PrimitiveType_LogicalAnd; + op->primitive->value.value = attr.release(); + return RET_OK; +} + +STATUS OnnxOrParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { + MS_LOG(DEBUG) << "onnx OrParser"; + if (op == nullptr) { + MS_LOG(ERROR) << "op is null"; + return RET_NULL_PTR; + } + op->primitive = std::make_unique(); + if (op->primitive == nullptr) { + MS_LOG(ERROR) << "op->primitive is null"; + return RET_NULL_PTR; + } + + std::unique_ptr attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return RET_NULL_PTR; + } + op->primitive->value.type = schema::PrimitiveType_LogicalOr; + op->primitive->value.value = attr.release(); + return RET_OK; +} + +STATUS OnnxNotParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { + MS_LOG(DEBUG) << "onnx NotParser"; + if (op == nullptr) { + MS_LOG(ERROR) << "op is null"; + return RET_NULL_PTR; + } + op->primitive = std::make_unique(); + if (op->primitive == nullptr) { + MS_LOG(ERROR) << "op->primitive is null"; + return RET_NULL_PTR; + } + + std::unique_ptr attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return RET_NULL_PTR; + } + op->primitive->value.type = schema::PrimitiveType_LogicalNot; + op->primitive->value.value = attr.release(); + return RET_OK; +} + +STATUS OnnxRoundParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, + schema::CNodeT *op) { + MS_LOG(DEBUG) << "onnx RoundParser"; + if (op == nullptr) { + MS_LOG(ERROR) << "op is null"; + return RET_NULL_PTR; + } + op->primitive = std::make_unique(); + if (op->primitive == nullptr) { + MS_LOG(ERROR) << "op->primitive is null"; + return RET_NULL_PTR; + } + + std::unique_ptr attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return RET_NULL_PTR; + } + op->primitive->value.type = schema::PrimitiveType_Round; + op->primitive->value.value = attr.release(); + return RET_OK; +} OnnxNodeRegistrar g_onnxAddParser("Add", new OnnxAddParser()); OnnxNodeRegistrar g_onnxInt8AddParser("Int8Add", new OnnxAddParser()); OnnxNodeRegistrar g_onnxSubParser("Sub", new OnnxSubParser()); @@ -608,5 +696,9 @@ OnnxNodeRegistrar g_onnxAtanParser("Atan", new OnnxAtanParser()); OnnxNodeRegistrar g_onnxAsinParser("Asin", new OnnxAsinParser()); OnnxNodeRegistrar g_onnxTanhParser("Tanh", new OnnxTanhParser()); OnnxNodeRegistrar g_onnxSignParser("Sign", new OnnxTanhParser()); +OnnxNodeRegistrar g_onnxAndParser("And", new OnnxAndParser()); +OnnxNodeRegistrar g_onnxOrParser("Or", new OnnxOrParser()); +OnnxNodeRegistrar g_onnxNotParser("Not", new OnnxNotParser()); +OnnxNodeRegistrar g_onnxRoundParser("Round", new OnnxRoundParser()); } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_arithmetic_operation_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_arithmetic_operation_parser.h index dd1fec8083..efcae3e3ed 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_arithmetic_operation_parser.h +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_arithmetic_operation_parser.h @@ -171,6 +171,30 @@ class OnnxSignParser : public OnnxNodeParser { OnnxSignParser() : OnnxNodeParser("Sign") {} STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; }; + +class OnnxAndParser : public OnnxNodeParser { + public: + OnnxAndParser() : OnnxNodeParser("And") {} + STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; +}; + +class OnnxOrParser : public OnnxNodeParser { + public: + OnnxOrParser() : OnnxNodeParser("Or") {} + STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; +}; + +class OnnxNotParser : public OnnxNodeParser { + public: + OnnxNotParser() : OnnxNodeParser("Not") {} + STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; +}; + +class OnnxRoundParser : public OnnxNodeParser { + public: + OnnxRoundParser() : OnnxNodeParser("Round") {} + STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; +}; } // namespace lite } // namespace mindspore #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_ARITHMETIC_OPREATION_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_conv_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_conv_parser.cc index 7a2a474eff..de35c328e8 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_conv_parser.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_conv_parser.cc @@ -15,9 +15,9 @@ */ #include "tools/converter/parser/onnx/onnx_conv_parser.h" -#include -#include #include +#include +#include namespace mindspore { namespace lite { @@ -176,9 +176,6 @@ STATUS OnnxConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod MS_LOG(ERROR) << "Convert Convolution to Depthwise failed"; return RET_ERROR; } - } else if (attr->group != 1) { - MS_LOG(ERROR) << "group conv hasn't supported"; - return RET_NOT_SUPPORT; } else { op->primitive->value.type = schema::PrimitiveType_Conv2D; op->primitive->value.value = attr.release(); diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_identity_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_identity_parser.cc new file mode 100644 index 0000000000..c414c64ae9 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_identity_parser.cc @@ -0,0 +1,49 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tools/converter/parser/onnx/onnx_identity_parser.h" +#include +#include + +namespace mindspore { +namespace lite { +STATUS OnnxIdentityParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, + schema::CNodeT *op) { + MS_LOG(DEBUG) << "onnx IdentityParser"; + if (op == nullptr) { + MS_LOG(ERROR) << "op is null"; + return RET_NULL_PTR; + } + op->primitive = std::make_unique(); + if (op->primitive == nullptr) { + MS_LOG(ERROR) << "op->primitive is null"; + return RET_NULL_PTR; + } + + std::unique_ptr attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return RET_NULL_PTR; + } + + op->primitive->value.type = schema::PrimitiveType_Identity; + op->primitive->value.value = attr.release(); + return RET_OK; +} + +OnnxNodeRegistrar g_onnxIdentityParser("Identity", new OnnxIdentityParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_identity_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_identity_parser.h new file mode 100644 index 0000000000..2b10c266d2 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_identity_parser.h @@ -0,0 +1,33 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_IDENTITY_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_IDENTITY_PARSER_H + +#include "tools/converter/parser/onnx/onnx_node_parser.h" +#include "tools/converter/parser/onnx/onnx_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class OnnxIdentityParser : public OnnxNodeParser { + public: + OnnxIdentityParser() : OnnxNodeParser("Identity") {} + + STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; +}; +} // namespace lite +} // namespace mindspore +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_IDENTITY_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_instance_norm_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_instance_norm_parser.cc new file mode 100644 index 0000000000..34039fb6ac --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_instance_norm_parser.cc @@ -0,0 +1,55 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tools/converter/parser/onnx/onnx_instance_norm_parser.h" +#include + +namespace mindspore { +namespace lite { +STATUS OnnxInstanceNormParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, + schema::CNodeT *op) { + MS_LOG(DEBUG) << "onnx InstanceNormParser"; + if (op == nullptr) { + MS_LOG(ERROR) << "op is null"; + return RET_NULL_PTR; + } + op->primitive = std::make_unique(); + if (op->primitive == nullptr) { + MS_LOG(ERROR) << "op->primitive is null"; + return RET_NULL_PTR; + } + + std::unique_ptr attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return RET_NULL_PTR; + } + + if (!onnx_node.attribute().empty()) { + auto onnx_node_attr = onnx_node.attribute().at(0); + if (onnx_node_attr.name() == "epsilon") { + attr->epsilon = onnx_node_attr.f(); + } + } + + op->primitive->value.type = schema::PrimitiveType_InstanceNorm; + op->primitive->value.value = attr.release(); + return RET_OK; +} + +OnnxNodeRegistrar g_onnxInstanceNormParser("InstanceNormalization", new OnnxInstanceNormParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_instance_norm_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_instance_norm_parser.h new file mode 100644 index 0000000000..924c345730 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_instance_norm_parser.h @@ -0,0 +1,33 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_INSTANCE_NORM_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_INSTANCE_NORM_PARSER_H + +#include "tools/converter/parser/onnx/onnx_node_parser.h" +#include "tools/converter/parser/onnx/onnx_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class OnnxInstanceNormParser : public OnnxNodeParser { + public: + OnnxInstanceNormParser() : OnnxNodeParser("InstanceNorm") {} + + STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; +}; +} // namespace lite +} // namespace mindspore +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_INSTANCE_NORM_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc index 59f0a6b5c6..259fff072b 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc @@ -15,12 +15,12 @@ */ #include "tools/converter/parser/onnx/onnx_model_parser.h" +#include #include #include -#include #include -#include "tools/common/graph_util.h" #include "src/common/utils.h" +#include "tools/common/graph_util.h" #include "tools/common/protobuf_utils.h" namespace mindspore { @@ -36,7 +36,8 @@ static const std::unordered_map TYPE_MAP = { {onnx::TensorProto_DataType_UINT32, mindspore::kNumberTypeUInt32}, {onnx::TensorProto_DataType_INT64, mindspore::kNumberTypeInt64}, {onnx::TensorProto_DataType_FLOAT16, mindspore::kNumberTypeFloat16}, - {onnx::TensorProto_DataType_FLOAT, mindspore::kNumberTypeFloat32}}; + {onnx::TensorProto_DataType_FLOAT, mindspore::kNumberTypeFloat32}, + {onnx::TensorProto_DataType_BOOL, mindspore::kNumberTypeBool}}; TypeId OnnxModelParser::GetDataTypeFromOnnx(onnx::TensorProto_DataType onnx_type) { auto iter = TYPE_MAP.find(onnx_type); @@ -161,9 +162,13 @@ STATUS OnnxModelParser::SetGraphOutputTensor(const onnx::GraphProto &onnx_graph, TensorCache *tensor_cache) { for (const auto &output_value : onnx_graph.output()) { int index; - const auto status = AddValueInfo(output_value, output_value.name(), OP_OUTPUT, tensor_cache, &index); - if (status != RET_OK) { - return status; + if (tensor_cache->FindTensor(output_value.name()) != -1) { + index = tensor_cache->FindTensor(output_value.name()); + } else { + const auto status = AddValueInfo(output_value, output_value.name(), OP_OUTPUT, tensor_cache, &index); + if (status != RET_OK) { + return status; + } } graph->outputIndex.emplace_back(index); MS_LOG(DEBUG) << "output_value name: " << output_value.name() << ", graph output index: " << index; @@ -250,7 +255,8 @@ STATUS OnnxModelParser::ParseOnnxGivenFillNode(const onnx::NodeProto &onnx_node, STATUS OnnxModelParser::ParseOnnxNodeToDstOp(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *dst_op, schema::TensorT *dst_tensor, - TensorCache *tensor_cache, const QuantType &quantType) { + TensorCache *tensor_cache, const QuantType &quantType, + schema::MetaGraphT *dst_graph) { // change op_type() to name(), that is unique static bool interrupt = false; dst_op->name = onnx_node.op_type() + "_" + onnx_node.output(0); @@ -260,23 +266,34 @@ STATUS OnnxModelParser::ParseOnnxNodeToDstOp(const onnx::GraphProto &onnx_graph, << onnx_node.input_size(); // get the real op type SetOpQuantParams(onnx_graph, onnx_node, dst_op, dst_tensor, tensor_cache); - auto node_parser = OnnxNodeParserRegistry::GetInstance()->GetNodeParser(onnx_node.op_type()); - if (node_parser == nullptr || interrupt) { + if (onnx_node.op_type() == "Loop") { + NoSupportOp::GetInstance()->InsertOp(onnx_node.op_type()); interrupt = true; - if (node_parser == nullptr) { - NoSupportOp::GetInstance()->InsertOp(onnx_node.op_type()); - } return RET_NOT_FIND_OP; - } - auto status = node_parser->Parse(onnx_graph, onnx_node, dst_op); - if (status != RET_OK) { - interrupt = true; - if (status == RET_NOT_SUPPORT) { - NoSupportOp::GetInstance()->InsertOp(onnx_node.op_type()); - } else { - MS_LOG(ERROR) << "parser onnx node " << onnx_node.op_type() << " attr failed"; + int status = ParseLoopAttr(dst_op, onnx_node, quantType, dst_graph); + if (status != RET_OK || interrupt) { + interrupt = true; + return status; + } + } else { + auto node_parser = OnnxNodeParserRegistry::GetInstance()->GetNodeParser(onnx_node.op_type()); + if (node_parser == nullptr || interrupt) { + interrupt = true; + if (node_parser == nullptr) { + NoSupportOp::GetInstance()->InsertOp(onnx_node.op_type()); + } + return RET_NOT_FIND_OP; + } + auto status = node_parser->Parse(onnx_graph, onnx_node, dst_op); + if (status != RET_OK) { + interrupt = true; + if (status == RET_NOT_FIND_OP) { + NoSupportOp::GetInstance()->InsertOp(onnx_node.op_type()); + } else { + MS_LOG(ERROR) << "parser onnx node " << onnx_node.op_type() << " attr failed"; + } + return status; } - return status; } // set op input index std::vector node_inputs; @@ -366,7 +383,7 @@ STATUS OnnxModelParser::SetOpInputIndex(const std::vector &node_inputs, const onnx::NodeProto &onnx_node, TensorCache *tensor_cache) { for (const auto &onnx_node_input : node_inputs) { if (onnx_node_input != "") { - auto index = tensor_cache->FindTensor(onnx_node_input); + int index = tensor_cache->FindTensor(onnx_node_input); if (index < 0) { MS_LOG(ERROR) << "input " << onnx_node_input << " of node " << onnx_node.name() << " can't be found"; return RET_ERROR; @@ -428,6 +445,9 @@ STATUS OnnxModelParser::CopyOnnxTensorData(const onnx::TensorProto &onnx_const_v } for (size_t i = 0; i < data_count; ++i) { if (in_data[i] > static_cast(INT32_MAX) || in_data[i] < static_cast(INT32_MIN)) { + if (llabs(in_data[i]) == INT64_MAX || in_data[i] == INT64_MIN) { + buffer[i] = in_data[i] > 0 ? INT32_MAX : INT32_MIN; + } MS_LOG(ERROR) << "int64 data " << in_data[i] << "too big to fit into int32"; return RET_ERROR; } else { @@ -438,6 +458,7 @@ STATUS OnnxModelParser::CopyOnnxTensorData(const onnx::TensorProto &onnx_const_v break; case kNumberTypeUInt8: case kNumberTypeInt8: + case kNumberTypeBool: data_size = data_count * sizeof(uint8_t); tensor_data = onnx_const_value.raw_data().data(); break; @@ -446,7 +467,7 @@ STATUS OnnxModelParser::CopyOnnxTensorData(const onnx::TensorProto &onnx_const_v return RET_ERROR; } tensor->data.resize(data_size); - if (memcpy_s(static_cast(tensor->data.data()), data_size, tensor_data, data_size) != 0) { + if (data_size != 0 && memcpy_s(static_cast(tensor->data.data()), data_size, tensor_data, data_size) != 0) { MS_LOG(ERROR) << "memcpy_s failed"; return RET_ERROR; } @@ -475,30 +496,39 @@ void OnnxModelParser::FindGraphInputAndConst(const onnx::GraphProto &onnx_graph) } } -schema::MetaGraphT *OnnxModelParser::ParseToFb(const std::string &modelFile, const std::string &weightFile, - const QuantType &quantType) { - int status = ValidateFileStr(modelFile, ".onnx"); - if (status != RET_OK) { - MS_LOG(ERROR) << "Input illegal: modelFile must be *.onnx"; - ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); - return nullptr; - } - - onnx::ModelProto onnx_model; - status = ReadProtoFromBinaryFile((const char *)modelFile.c_str(), &onnx_model); - if (status != RET_OK) { - MS_LOG(ERROR) << "Read onnx model file failed, model path: " << modelFile; - ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); - return nullptr; - } - const onnx::GraphProto &onnx_graph = onnx_model.graph(); - MS_LOG(INFO) << "model producer name: " << onnx_model.producer_name() << ", graph name: " << onnx_graph.name(); +STATUS OnnxModelParser::ParseLoopAttr(schema::CNodeT *dst_op, const onnx::NodeProto &onnx_node, + const QuantType &quantType, schema::MetaGraphT *dst_graph) { + MS_LOG(DEBUG) << "onnx LoopParser"; + if (dst_op == nullptr) { + MS_LOG(ERROR) << "op is null"; + return RET_NULL_PTR; + } + dst_op->primitive = std::make_unique(); + if (dst_op->primitive == nullptr) { + MS_LOG(ERROR) << "op->primitive is null"; + return RET_NULL_PTR; + } + std::unique_ptr attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return RET_NULL_PTR; + } + attr->subGraphIndex = subGraphNum; + auto sub_graph = std::make_unique(); + sub_graph.reset(ParseGraph(onnx_node.attribute().at(0).g(), quantType)); + dst_graph->subGraph.push_back(std::move(sub_graph)); + subGraphNum += 1; + dst_op->primitive->value.type = schema::PrimitiveType_Loop; + dst_op->primitive->value.value = attr.release(); + return RET_OK; +} +schema::MetaGraphT *OnnxModelParser::ParseGraph(const onnx::GraphProto &onnx_graph, const QuantType &quantType) { TensorCache tensor_cache; // dst_graph->name = onnx_graph.name(); // this is not used // find out input names and const names FindGraphInputAndConst(onnx_graph); // set const tensor - status = SetGraphConstTensor(onnx_graph, &tensor_cache); + int status = SetGraphConstTensor(onnx_graph, &tensor_cache); if (status != RET_OK) { MS_LOG(ERROR) << "SetGraphConstTensor failed"; ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); @@ -512,13 +542,7 @@ schema::MetaGraphT *OnnxModelParser::ParseToFb(const std::string &modelFile, con ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); return nullptr; } - // init onnx model graph output tensor - status = SetGraphOutputTensor(onnx_graph, dst_graph.get(), &tensor_cache); - if (status != RET_OK) { - MS_LOG(ERROR) << "SetGraphOutputTensor failed"; - ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); - return nullptr; - } + // init op node input/output tensor, and dst_op attr NoSupportOp::GetInstance()->SetFmkType("ONNX"); for (const auto &onnx_node : onnx_graph.node()) { @@ -544,7 +568,8 @@ schema::MetaGraphT *OnnxModelParser::ParseToFb(const std::string &modelFile, con std::unique_ptr dst_op = std::make_unique(); std::unique_ptr dst_tensor = std::make_unique(); - status_node = ParseOnnxNodeToDstOp(onnx_graph, onnx_node, dst_op.get(), dst_tensor.get(), &tensor_cache, quantType); + status_node = ParseOnnxNodeToDstOp(onnx_graph, onnx_node, dst_op.get(), dst_tensor.get(), &tensor_cache, quantType, + dst_graph.get()); if (status_node != RET_OK) { status = (status == RET_OK ? status_node : status); continue; @@ -558,9 +583,42 @@ schema::MetaGraphT *OnnxModelParser::ParseToFb(const std::string &modelFile, con } return nullptr; } + // init onnx model graph output tensor + status = SetGraphOutputTensor(onnx_graph, dst_graph.get(), &tensor_cache); + if (status != RET_OK) { + MS_LOG(ERROR) << "SetGraphOutputTensor failed"; + ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); + return nullptr; + } SetAllTensors(tensor_cache, dst_graph.get()); - dst_graph->name = GetModelName(modelFile); return dst_graph.release(); } +schema::MetaGraphT *OnnxModelParser::ParseToFb(const std::string &modelFile, const std::string &weightFile, + const QuantType &quantType) { + int status = ValidateFileStr(modelFile, ".onnx"); + if (status != RET_OK) { + MS_LOG(ERROR) << "Input illegal: modelFile must be *.onnx"; + ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); + return nullptr; + } + + onnx::ModelProto onnx_model; + status = ReadProtoFromBinaryFile((const char *)modelFile.c_str(), &onnx_model); + if (status != RET_OK) { + MS_LOG(ERROR) << "Read onnx model file failed, model path: " << modelFile; + ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); + return nullptr; + } + const onnx::GraphProto &onnx_graph = onnx_model.graph(); + MS_LOG(INFO) << "model producer name: " << onnx_model.producer_name() << ", graph name: " << onnx_graph.name(); + + schema::MetaGraphT *dst_graph = ParseGraph(onnx_graph, quantType); + if (dst_graph == nullptr) { + return nullptr; + } + dst_graph->name = GetModelName(modelFile); + return dst_graph; +} + } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.h index ec3adcb0ce..d5f3b95b97 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.h +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.h @@ -26,6 +26,7 @@ #include #include #include +#include #include "securec/include/securec.h" #include "tools/converter/model_parser.h" #include "tools/converter/parser/onnx/onnx_node_parser_registry.h" @@ -40,6 +41,7 @@ class OnnxModelParser : public ModelParser { virtual ~OnnxModelParser(); + schema::MetaGraphT *ParseGraph(const onnx::GraphProto &graph, const QuantType &quantType = QuantType_QUANT_NONE); schema::MetaGraphT *ParseToFb(const std::string &modelFile, const std::string &weightFile, const QuantType &quantType = QuantType_QUANT_NONE) override; @@ -62,7 +64,7 @@ class OnnxModelParser : public ModelParser { STATUS ParseOnnxNodeToDstOp(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *dst_op, schema::TensorT *dst_tensor, TensorCache *tensor_cache, - const QuantType &quantType); + const QuantType &quantType, schema::MetaGraphT *dst_graph); void ParseOnnxGemmNode(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::MetaGraphT *graph, TensorCache *tensor_cache, const QuantType &quant_type); @@ -86,9 +88,13 @@ class OnnxModelParser : public ModelParser { void FindGraphInputAndConst(const onnx::GraphProto &onnx_graph); + STATUS ParseLoopAttr(schema::CNodeT *dst_op, const onnx::NodeProto &onnx_node, const QuantType &quantType, + schema::MetaGraphT *dst_graph); + private: - std::vector graphInputNames; - std::vector graphConstNames; + std::vector graphInputNames; + std::vector graphConstNames; + int subGraphNum = 0; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_non_max_suppression_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_non_max_suppression_parser.cc new file mode 100644 index 0000000000..f821451830 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_non_max_suppression_parser.cc @@ -0,0 +1,80 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tools/converter/parser/onnx/onnx_non_max_suppression_parser.h" +#include + +namespace mindspore { +namespace lite { +STATUS OnnxNonMaxSuppressionParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, + schema::CNodeT *op) { + MS_LOG(DEBUG) << "onnx EluParser"; + if (op == nullptr) { + MS_LOG(ERROR) << "op is null"; + return RET_NULL_PTR; + } + op->primitive = std::make_unique(); + if (op->primitive == nullptr) { + MS_LOG(ERROR) << "op->primitive is null"; + return RET_NULL_PTR; + } + + std::unique_ptr attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return RET_NULL_PTR; + } + if (onnx_node.input_size() > 2) { + auto it = std::find_if(onnx_graph.initializer().begin(), onnx_graph.initializer().end(), + [&](const onnx::TensorProto &it) { return it.name() == onnx_node.input(2); }); + if (it != onnx_graph.initializer().end()) { + attr->maxOutBoxPerClass = it->int64_data(0); + } + } + + if (onnx_node.input_size() > 3) { + auto it = std::find_if(onnx_graph.initializer().begin(), onnx_graph.initializer().end(), + [&](const onnx::TensorProto &it) { return it.name() == onnx_node.input(3); }); + if (it != onnx_graph.initializer().end()) { + attr->iouThreshold = it->float_data(0); + } + } + + if (onnx_node.input_size() > 4) { + auto it = std::find_if(onnx_graph.initializer().begin(), onnx_graph.initializer().end(), + [&](const onnx::TensorProto &it) { return it.name() == onnx_node.input(4); }); + if (it != onnx_graph.initializer().end()) { + attr->scoreThreshold = it->float_data(0); + } + } + + for (const auto &onnx_node_attr : onnx_node.attribute()) { + const auto &attribute_name = onnx_node_attr.name(); + if (attribute_name == "center_point_box") { + if (onnx_node_attr.has_i()) { + attr->centerPointBox = onnx_node_attr.i(); + } + } + } + + op->primitive->value.type = schema::PrimitiveType_Elu; + op->primitive->value.value = attr.release(); + return RET_OK; +} + +OnnxNodeRegistrar g_onnxNonMaxSuppressionParser("NonMaxSuppression", new OnnxNonMaxSuppressionParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_non_max_suppression_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_non_max_suppression_parser.h new file mode 100644 index 0000000000..46f3a7949b --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_non_max_suppression_parser.h @@ -0,0 +1,33 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_NON_MAX_SUPPRESSION_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_NON_MAX_SUPPRESSION_PARSER_H + +#include "tools/converter/parser/onnx/onnx_node_parser.h" +#include "tools/converter/parser/onnx/onnx_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class OnnxNonMaxSuppressionParser : public OnnxNodeParser { + public: + OnnxNonMaxSuppressionParser() : OnnxNodeParser("NonMaxSuppression") {} + + STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; +}; +} // namespace lite +} // namespace mindspore +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_NON_MAX_SUPPRESSION_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_resize_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_resize_parser.cc new file mode 100644 index 0000000000..7ca5db4f87 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_resize_parser.cc @@ -0,0 +1,95 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tools/converter/parser/onnx/onnx_resize_parser.h" +#include +#include +#include +#include + +namespace mindspore { +namespace lite { +STATUS OnnxResizeParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, + schema::CNodeT *op) { + MS_LOG(DEBUG) << "onnx ResizeParser"; + if (op == nullptr) { + MS_LOG(ERROR) << "op is null"; + return RET_NULL_PTR; + } + op->primitive = std::make_unique(); + if (op->primitive == nullptr) { + MS_LOG(ERROR) << "op->primitive is null"; + return RET_NULL_PTR; + } + + std::unique_ptr attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return RET_NULL_PTR; + } + + attr->format = schema::Format_NCHW; + attr->nearestMode = schema::NearestMode_ROUND_HALF_DOWN; + for (const auto &onnx_node_attr : onnx_node.attribute()) { + const auto &attribute_name = onnx_node_attr.name(); + if (attribute_name == "coordinate_transformation_mode") { + attr->coordinateTransformMode = [&]() { + std::map transform_map = { + {"half_pixel", schema::CoordinateTransformMode_HALF_PIXEL}, + {"pytorch_half_pixel", schema::CoordinateTransformMode_PYTORCH_HALF_PIXEL}, + {"align_corners", schema::CoordinateTransformMode_ALIGN_CORNERS}, + {"asymmetric", schema::CoordinateTransformMode_ASYMMETRIC}, + {"tf_half_pixel_for_nn", schema::CoordinateTransformMode_TF_HALF_PIXEL}, + {"tf_crop_and_resize", schema::CoordinateTransformMode_TF_CROP_AND_RESIZE}, + }; + return transform_map[onnx_node_attr.strings(0)]; + }(); + } else if (attribute_name == "cubic_coeff_a") { + attr->cubicCoeff = onnx_node_attr.f(); + } else if (attribute_name == "exclude_outside") { + attr->excludeOutside = onnx_node_attr.i(); + } else if (attribute_name == "extrapolation_value") { + attr->extrapolationValue = onnx_node_attr.f(); + } else if (attribute_name == "mode") { + attr->method = [&]() { + std::map resize_mode = { + {"nearest", schema::ResizeMethod_NEAREST}, + {"linear", schema::ResizeMethod_LINEAR}, + {"cubic", schema::ResizeMethod_CUBIC}, + }; + return resize_mode[onnx_node_attr.strings(0)]; + }(); + } else if (attribute_name == "nearest_mode") { + attr->nearestMode = [&]() { + std::map nearest_mode = { + {"round_prefer_floor", schema::NearestMode_ROUND_HALF_DOWN}, + {"round_prefer_ceil", schema::NearestMode_ROUND_HALF_UP}, + {"floor", schema::NearestMode_FLOOR}, + {"ceil", schema::NearestMode_CEIL}, + }; + return nearest_mode[onnx_node_attr.strings(0)]; + }(); + } + } + + op->primitive->value.type = schema::PrimitiveType_Resize; + op->primitive->value.value = attr.release(); + return RET_OK; +} + +OnnxNodeRegistrar g_onnxResizeParser("Resize", new OnnxResizeParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_resize_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_resize_parser.h new file mode 100644 index 0000000000..8925ba0fcc --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_resize_parser.h @@ -0,0 +1,33 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_RESIZE_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_RESIZE_PARSER_H + +#include "tools/converter/parser/onnx/onnx_node_parser.h" +#include "tools/converter/parser/onnx/onnx_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class OnnxResizeParser : public OnnxNodeParser { + public: + OnnxResizeParser() : OnnxNodeParser("Resize") {} + + STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; +}; +} // namespace lite +} // namespace mindspore +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_RESIZE_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_upsample_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_upsample_parser.cc index b160d1372e..562f6963c4 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_upsample_parser.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_upsample_parser.cc @@ -14,8 +14,8 @@ * limitations under the License. */ -#include #include "tools/converter/parser/onnx/onnx_upsample_parser.h" +#include namespace mindspore { namespace lite { @@ -42,9 +42,9 @@ STATUS OnnxUpsampleParser::Parse(const onnx::GraphProto &onnx_graph, const onnx: const auto &attribute_name = onnx_node_attr.name(); if (attribute_name == "mode") { if ("nearest" == onnx_node_attr.s()) { - attr->method = schema::ResizeMethod_NEAREST_NEIGHBOR; + attr->method = schema::ResizeMethod_NEAREST; } else if ("bilinear" == onnx_node_attr.s()) { - attr->method = schema::ResizeMethod_BILINEAR; + attr->method = schema::ResizeMethod_LINEAR; } else { MS_LOG(ERROR) << "Resize do not support upsample mode"; return RET_ERROR; diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_custom_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_custom_parser.cc index 630fdaf9ff..fcf73d0322 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_custom_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_custom_parser.cc @@ -15,9 +15,9 @@ */ #include "tools/converter/parser/tflite/tflite_custom_parser.h" -#include -#include #include +#include +#include #include "flatbuffers/flatbuffers.h" #include "flatbuffers/flexbuffers.h" @@ -206,6 +206,8 @@ STATUS TfliteCustomParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni status = ExtractFeatures(custom_attr, op, tflite_op); } else if (custom_type == "AudioSpectrogram") { status = AudioSpectrogram(custom_attr, op, tflite_op); + } else if (custom_type == "Mfcc") { + status = Mfcc(custom_attr, op, tflite_op); } else if (custom_type == "FlexRFFT") { status = Rfft(custom_attr, op, tflite_op, tflite_model); } else if (custom_type == "FlexReal") { diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_resize_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_resize_parser.cc index ad16c51833..9c2080284f 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_resize_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_resize_parser.cc @@ -15,10 +15,10 @@ */ #include "tools/converter/parser/tflite/tflite_resize_parser.h" -#include +#include #include #include -#include +#include namespace mindspore { namespace lite { @@ -39,7 +39,7 @@ STATUS TfliteResizeParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni MS_LOG(ERROR) << "new op failed"; return RET_NULL_PTR; } - + attr->coordinateTransformMode = schema::CoordinateTransformMode_COMMON; std::vector node_name_str; Split(op->name.data(), &node_name_str, "-"); const char *node_name = node_name_str.data()->c_str(); @@ -50,8 +50,16 @@ STATUS TfliteResizeParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed"; return RET_NULL_PTR; } - attr->alignCorners = tfliteAttr->align_corners; - attr->method = schema::ResizeMethod_BILINEAR; + if (tfliteAttr->align_corners) { + attr->alignCorners = tfliteAttr->align_corners; + attr->coordinateTransformMode = schema::CoordinateTransformMode_ALIGN_CORNERS; + } + if (tfliteAttr->half_pixel_centers) { + attr->coordinateTransformMode = (attr->coordinateTransformMode == schema::CoordinateTransformMode_COMMON + ? schema::CoordinateTransformMode_TF_HALF_PIXEL + : schema::CoordinateTransformMode_ALIGN_CORNERS_WITH_HALF_PIEXL); + } + attr->method = schema::ResizeMethod_LINEAR; } else if (std::strcmp(node_name, "NearestNeighbor") == 0) { MS_LOG(DEBUG) << "parse TfliteResizeNearestNeighborParser"; const auto &tfliteAttr = tflite_op->builtin_options.AsResizeNearestNeighborOptions(); @@ -59,8 +67,17 @@ STATUS TfliteResizeParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed"; return RET_NULL_PTR; } - attr->alignCorners = tfliteAttr->align_corners; - attr->method = schema::ResizeMethod_NEAREST_NEIGHBOR; + if (tfliteAttr->align_corners) { + attr->alignCorners = tfliteAttr->align_corners; + attr->coordinateTransformMode = schema::CoordinateTransformMode_ALIGN_CORNERS; + } + if (tfliteAttr->half_pixel_centers) { + attr->coordinateTransformMode = (attr->coordinateTransformMode == schema::CoordinateTransformMode_COMMON + ? schema::CoordinateTransformMode_TF_HALF_PIXEL + : schema::CoordinateTransformMode_ALIGN_CORNERS_WITH_HALF_PIEXL); + } + attr->method = schema::ResizeMethod_NEAREST; + attr->nearestMode = schema::NearestMode_NORMAL; } else { MS_LOG(ERROR) << "wrong resize type"; return RET_ERROR;