From f8267c821d23f5b5920a1eee0aa1a328c50e8b33 Mon Sep 17 00:00:00 2001 From: xuanyue Date: Fri, 4 Dec 2020 14:36:28 +0800 Subject: [PATCH] adjust onnx upsample --- mindspore/lite/schema/ops.fbs | 2 +- mindspore/lite/src/ops/bias_add.cc | 16 +--------------- mindspore/lite/src/ops/bias_add.h | 2 -- .../parser/onnx/onnx_upsample_parser.cc | 15 +++++---------- 4 files changed, 7 insertions(+), 28 deletions(-) diff --git a/mindspore/lite/schema/ops.fbs b/mindspore/lite/schema/ops.fbs index 0c1fa797bf..c222730c6f 100644 --- a/mindspore/lite/schema/ops.fbs +++ b/mindspore/lite/schema/ops.fbs @@ -352,7 +352,7 @@ table FakeQuantWithMinMaxVars { } table BiasAdd { - axis: [int]; + axis: [int]; // DEPRECATED } table ROIPooling { diff --git a/mindspore/lite/src/ops/bias_add.cc b/mindspore/lite/src/ops/bias_add.cc index 13ecc02d4e..cdb0b56f36 100644 --- a/mindspore/lite/src/ops/bias_add.cc +++ b/mindspore/lite/src/ops/bias_add.cc @@ -24,10 +24,6 @@ namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE -std::vector BiasAdd::GetAxis() const { return this->primitive_->value.AsBiasAdd()->axis; } - -void BiasAdd::SetAxis(const std::vector &axis) { this->primitive_->value.AsBiasAdd()->axis = axis; } - int BiasAdd::UnPackAttr(const Primitive &prim, const std::vector &inputs) { if (this->primitive_ == nullptr) { this->primitive_ = new (std::nothrow) schema::PrimitiveT; @@ -67,21 +63,11 @@ int BiasAdd::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers MS_LOG(ERROR) << "value_as_BiasAdd return nullptr"; return RET_ERROR; } - std::vector axis; - if (attr->axis() != nullptr) { - for (int i = 0; i < static_cast(attr->axis()->size()); i++) { - axis.push_back(attr->axis()->data()[i]); - } - } - auto val_offset = schema::CreateBiasAddDirect(*fbb, &axis); + auto val_offset = schema::CreateBiasAddDirect(*fbb); auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_BiasAdd, val_offset.o); fbb->Finish(prim_offset); return RET_OK; } -std::vector BiasAdd::GetAxis() const { - auto fb_vector = this->primitive_->value_as_BiasAdd()->axis(); - return std::vector(fb_vector->begin(), fb_vector->end()); -} PrimitiveC *BiasAddCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC(primitive); } Registry BiasAddRegistry(schema::PrimitiveType_BiasAdd, BiasAddCreator); diff --git a/mindspore/lite/src/ops/bias_add.h b/mindspore/lite/src/ops/bias_add.h index 218db90506..d1cdf391e2 100644 --- a/mindspore/lite/src/ops/bias_add.h +++ b/mindspore/lite/src/ops/bias_add.h @@ -33,11 +33,9 @@ class BiasAdd : public PrimitiveC { MS_DECLARE_PARENT(BiasAdd, PrimitiveC); explicit BiasAdd(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} int UnPackAttr(const Primitive &prim, const std::vector &inputs) override; - void SetAxis(const std::vector &axis); #else int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; #endif - std::vector GetAxis() const; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_upsample_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_upsample_parser.cc index 7a2a3acad7..79893ec652 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_upsample_parser.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_upsample_parser.cc @@ -38,22 +38,17 @@ STATUS OnnxUpsampleParser::Parse(const onnx::GraphProto &onnx_graph, const onnx: return RET_NULL_PTR; } + attr->method = schema::ResizeMethod_NEAREST; for (const auto &onnx_node_attr : onnx_node.attribute()) { const auto &attribute_name = onnx_node_attr.name(); if (attribute_name == "mode") { - if ("nearest" == onnx_node_attr.s()) { - attr->method = schema::ResizeMethod_NEAREST; - } else if ("bilinear" == onnx_node_attr.s()) { - attr->method = schema::ResizeMethod_LINEAR; - } else { - MS_LOG(ERROR) << "Resize do not support upsample mode"; - return RET_ERROR; + if (onnx_node_attr.s() != "nearest" && onnx_node_attr.s() != "linear") { + MS_LOG(ERROR) << "the upsample mode don't support now."; + return RET_NOT_SUPPORT; } + attr->method = onnx_node_attr.s() == "nearest" ? schema::ResizeMethod_NEAREST : schema::ResizeMethod_LINEAR; } } - attr->newWidth = 1; - attr->newHeight = 1; - attr->alignCorners = false; op->primitive->value.type = schema::PrimitiveType_Resize; op->primitive->value.value = attr.release(); return RET_OK;