| @@ -352,7 +352,7 @@ table FakeQuantWithMinMaxVars { | |||||
| } | } | ||||
| table BiasAdd { | table BiasAdd { | ||||
| axis: [int]; | |||||
| axis: [int]; // DEPRECATED | |||||
| } | } | ||||
| table ROIPooling { | table ROIPooling { | ||||
| @@ -24,10 +24,6 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| #ifdef PRIMITIVE_WRITEABLE | #ifdef PRIMITIVE_WRITEABLE | ||||
| std::vector<int> BiasAdd::GetAxis() const { return this->primitive_->value.AsBiasAdd()->axis; } | |||||
| void BiasAdd::SetAxis(const std::vector<int> &axis) { this->primitive_->value.AsBiasAdd()->axis = axis; } | |||||
| int BiasAdd::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) { | int BiasAdd::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) { | ||||
| if (this->primitive_ == nullptr) { | if (this->primitive_ == nullptr) { | ||||
| this->primitive_ = new (std::nothrow) schema::PrimitiveT; | this->primitive_ = new (std::nothrow) schema::PrimitiveT; | ||||
| @@ -67,21 +63,11 @@ int BiasAdd::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers | |||||
| MS_LOG(ERROR) << "value_as_BiasAdd return nullptr"; | MS_LOG(ERROR) << "value_as_BiasAdd return nullptr"; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| std::vector<int32_t> axis; | |||||
| if (attr->axis() != nullptr) { | |||||
| for (int i = 0; i < static_cast<int>(attr->axis()->size()); i++) { | |||||
| axis.push_back(attr->axis()->data()[i]); | |||||
| } | |||||
| } | |||||
| auto val_offset = schema::CreateBiasAddDirect(*fbb, &axis); | |||||
| auto val_offset = schema::CreateBiasAddDirect(*fbb); | |||||
| auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_BiasAdd, val_offset.o); | auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_BiasAdd, val_offset.o); | ||||
| fbb->Finish(prim_offset); | fbb->Finish(prim_offset); | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| std::vector<int> BiasAdd::GetAxis() const { | |||||
| auto fb_vector = this->primitive_->value_as_BiasAdd()->axis(); | |||||
| return std::vector<int>(fb_vector->begin(), fb_vector->end()); | |||||
| } | |||||
| PrimitiveC *BiasAddCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<BiasAdd>(primitive); } | PrimitiveC *BiasAddCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<BiasAdd>(primitive); } | ||||
| Registry BiasAddRegistry(schema::PrimitiveType_BiasAdd, BiasAddCreator); | Registry BiasAddRegistry(schema::PrimitiveType_BiasAdd, BiasAddCreator); | ||||
| @@ -33,11 +33,9 @@ class BiasAdd : public PrimitiveC { | |||||
| MS_DECLARE_PARENT(BiasAdd, PrimitiveC); | MS_DECLARE_PARENT(BiasAdd, PrimitiveC); | ||||
| explicit BiasAdd(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} | explicit BiasAdd(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} | ||||
| int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override; | int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override; | ||||
| void SetAxis(const std::vector<int> &axis); | |||||
| #else | #else | ||||
| int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; | int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; | ||||
| #endif | #endif | ||||
| std::vector<int> GetAxis() const; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -38,22 +38,17 @@ STATUS OnnxUpsampleParser::Parse(const onnx::GraphProto &onnx_graph, const onnx: | |||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| } | } | ||||
| attr->method = schema::ResizeMethod_NEAREST; | |||||
| for (const auto &onnx_node_attr : onnx_node.attribute()) { | for (const auto &onnx_node_attr : onnx_node.attribute()) { | ||||
| const auto &attribute_name = onnx_node_attr.name(); | const auto &attribute_name = onnx_node_attr.name(); | ||||
| if (attribute_name == "mode") { | if (attribute_name == "mode") { | ||||
| if ("nearest" == onnx_node_attr.s()) { | |||||
| attr->method = schema::ResizeMethod_NEAREST; | |||||
| } else if ("bilinear" == onnx_node_attr.s()) { | |||||
| attr->method = schema::ResizeMethod_LINEAR; | |||||
| } else { | |||||
| MS_LOG(ERROR) << "Resize do not support upsample mode"; | |||||
| return RET_ERROR; | |||||
| if (onnx_node_attr.s() != "nearest" && onnx_node_attr.s() != "linear") { | |||||
| MS_LOG(ERROR) << "the upsample mode don't support now."; | |||||
| return RET_NOT_SUPPORT; | |||||
| } | } | ||||
| attr->method = onnx_node_attr.s() == "nearest" ? schema::ResizeMethod_NEAREST : schema::ResizeMethod_LINEAR; | |||||
| } | } | ||||
| } | } | ||||
| attr->newWidth = 1; | |||||
| attr->newHeight = 1; | |||||
| attr->alignCorners = false; | |||||
| op->primitive->value.type = schema::PrimitiveType_Resize; | op->primitive->value.type = schema::PrimitiveType_Resize; | ||||
| op->primitive->value.value = attr.release(); | op->primitive->value.value = attr.release(); | ||||
| return RET_OK; | return RET_OK; | ||||