From eb13519e71fdef895430a150f870c56302043e6c Mon Sep 17 00:00:00 2001 From: gongdaguo Date: Thu, 17 Dec 2020 10:57:01 +0800 Subject: [PATCH] add mod parser, add ArgMinWithValue parser --- mindspore/lite/schema/model.fbs | 1 + mindspore/lite/schema/ops.fbs | 3 + mindspore/lite/src/ops/argmin.cc | 37 +++++++++- mindspore/lite/src/ops/argmin.h | 1 + mindspore/lite/src/ops/mod.cc | 70 +++++++++++++++++++ mindspore/lite/src/ops/mod.h | 42 +++++++++++ .../src/ops/populate/arithmetic_populate.cc | 1 + mindspore/lite/src/ops/primitive_c.cc | 7 ++ 8 files changed, 161 insertions(+), 1 deletion(-) create mode 100644 mindspore/lite/src/ops/mod.cc create mode 100644 mindspore/lite/src/ops/mod.h diff --git a/mindspore/lite/schema/model.fbs b/mindspore/lite/schema/model.fbs index 19737f823f..6a5d69724e 100644 --- a/mindspore/lite/schema/model.fbs +++ b/mindspore/lite/schema/model.fbs @@ -260,6 +260,7 @@ union PrimitiveType { SigmoidCrossEntropyWithLogitsGrad, Reciprocal, Merge, + Mod, } enum QuantType: int { diff --git a/mindspore/lite/schema/ops.fbs b/mindspore/lite/schema/ops.fbs index 942d75ce3e..4d6fee03a6 100644 --- a/mindspore/lite/schema/ops.fbs +++ b/mindspore/lite/schema/ops.fbs @@ -890,6 +890,9 @@ table FloorDiv { table FloorMod { } +table Mod { +} + table L2Norm { axis: [int]; epsilon: float; diff --git a/mindspore/lite/src/ops/argmin.cc b/mindspore/lite/src/ops/argmin.cc index acd6272ae8..f7a4a72fdd 100644 --- a/mindspore/lite/src/ops/argmin.cc +++ b/mindspore/lite/src/ops/argmin.cc @@ -35,6 +35,36 @@ void ArgMin::SetTopK(int top_k) { this->primitive_->value.AsArgMin()->topK = top void ArgMin::SetKeepDims(bool keep_dims) { this->primitive_->value.AsArgMin()->keepDims = keep_dims; } void ArgMin::SetAxisType(int axis_type) { this->primitive_->value.AsArgMin()->axisType = axis_type; } +int ArgMin::UnPackAttr(const Primitive &prim, const std::vector &inputs) { + if (this->primitive_ == nullptr) { + this->primitive_ = new (std::nothrow) schema::PrimitiveT; + if (this->primitive_ == nullptr) { + MS_LOG(ERROR) << "new primitiveT failed"; + return RET_ERROR; + } + this->primitive_->value.type = schema::PrimitiveType_ArgMin; + } + if (this->primitive_->value.type != schema::PrimitiveType_ArgMin) { + MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type; + return RET_ERROR; + } + if (this->primitive_->value.value == nullptr) { + auto attr = new (std::nothrow) schema::ArgMinT(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new primitiveT value failed"; + return RET_ERROR; + } + this->primitive_->value.value = attr; + if (prim.GetAttr("axis") != nullptr) { + attr->axis = static_cast(GetValue(prim.GetAttr("axis"))); + } + if (prim.GetAttr("keep_dims") != nullptr) { + attr->keepDims = static_cast(GetValue(prim.GetAttr("keep_dims"))); + } + } + return RET_OK; +} + #else int ArgMin::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { MS_ASSERT(nullptr != primitive); @@ -66,7 +96,7 @@ int ArgMin::InferShape(std::vector inputs_, std::vector kDoubleNum) { MS_LOG(ERROR) << "tensor number is error."; } output->set_format(input->format()); @@ -88,6 +118,11 @@ int ArgMin::InferShape(std::vector inputs_, std::vectorset_shape(output_shape); + if (outputs_.size() == kDoubleNum) { + outputs_.at(1)->set_format(input->format()); + outputs_.at(1)->set_data_type(input->data_type()); + outputs_.at(1)->set_shape(output_shape); + } return RET_OK; } } // namespace lite diff --git a/mindspore/lite/src/ops/argmin.h b/mindspore/lite/src/ops/argmin.h index d159f99959..4a1ab9af12 100644 --- a/mindspore/lite/src/ops/argmin.h +++ b/mindspore/lite/src/ops/argmin.h @@ -36,6 +36,7 @@ class ArgMin : public PrimitiveC { void SetTopK(int top_k); void SetKeepDims(bool keep_dims); void SetAxisType(int axis_type); + int UnPackAttr(const Primitive &prim, const std::vector &inputs) override; #else int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; #endif diff --git a/mindspore/lite/src/ops/mod.cc b/mindspore/lite/src/ops/mod.cc new file mode 100644 index 0000000000..ebcaa6458d --- /dev/null +++ b/mindspore/lite/src/ops/mod.cc @@ -0,0 +1,70 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ops/mod.h" + +#ifndef PRIMITIVE_WRITEABLE +#include "src/ops/ops_register.h" +#endif + +namespace mindspore { +namespace lite { +#ifdef PRIMITIVE_WRITEABLE + +int Mod::UnPackAttr(const Primitive &prim, const std::vector &inputs) { + if (this->primitive_ == nullptr) { + this->primitive_ = new (std::nothrow) schema::PrimitiveT; + if (this->primitive_ == nullptr) { + MS_LOG(ERROR) << "new primitiveT failed"; + return RET_ERROR; + } + this->primitive_->value.type = schema::PrimitiveType_Mod; + } + if (this->primitive_->value.type != schema::PrimitiveType_Mod) { + MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type; + delete this->primitive_; + this->primitive_ = nullptr; + return RET_ERROR; + } + if (this->primitive_->value.value == nullptr) { + auto attr = new (std::nothrow) schema::ModT(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new primitiveT value failed"; + delete this->primitive_; + this->primitive_ = nullptr; + return RET_ERROR; + } + this->primitive_->value.value = attr; + } + return RET_OK; +} + +#else + +int Mod::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { + MS_ASSERT(nullptr != primitive); + MS_ASSERT(nullptr != fbb); + auto val_offset = schema::CreateMod(*fbb); + auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Mod, val_offset.o); + fbb->Finish(prim_offset); + return RET_OK; +} +PrimitiveC *ModCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC(primitive); } +Registry ModRegistry(schema::PrimitiveType_Mod, ModCreator); +#endif + +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/src/ops/mod.h b/mindspore/lite/src/ops/mod.h new file mode 100644 index 0000000000..3a351e6889 --- /dev/null +++ b/mindspore/lite/src/ops/mod.h @@ -0,0 +1,42 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LITE_MINDSPORE_LITE_C_OPS_MOD_H_ +#define LITE_MINDSPORE_LITE_C_OPS_MOD_H_ + +#include +#include +#include +#include "src/ops/arithmetic.h" + +namespace mindspore { +namespace lite { +class Mod : public Arithmetic { + public: + Mod() = default; + ~Mod() = default; +#ifdef PRIMITIVE_WRITEABLE + MS_DECLARE_PARENT(Mod, Arithmetic); + explicit Mod(schema::PrimitiveT *primitive) : Arithmetic(primitive) {} + int UnPackAttr(const Primitive &prim, const std::vector &inputs) override; +#else + int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; +#endif +}; +} // namespace lite +} // namespace mindspore + +#endif // LITE_MINDSPORE_LITE_C_OPS_FLOOR_MOD_H_ diff --git a/mindspore/lite/src/ops/populate/arithmetic_populate.cc b/mindspore/lite/src/ops/populate/arithmetic_populate.cc index 74dffde9d7..d02a050859 100644 --- a/mindspore/lite/src/ops/populate/arithmetic_populate.cc +++ b/mindspore/lite/src/ops/populate/arithmetic_populate.cc @@ -67,6 +67,7 @@ Registry MaximumParameterRegistry(schema::PrimitiveType_Maximum, PopulateArithme Registry MinimumParameterRegistry(schema::PrimitiveType_Minimum, PopulateArithmetic); Registry FloorDivParameterRegistry(schema::PrimitiveType_FloorDiv, PopulateArithmetic); Registry FloorModParameterRegistry(schema::PrimitiveType_FloorMod, PopulateArithmetic); +Registry ModParameterRegistry(schema::PrimitiveType_Mod, PopulateArithmetic); Registry SquaredDifferenceParameterRegistry(schema::PrimitiveType_SquaredDifference, PopulateArithmetic); } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/primitive_c.cc b/mindspore/lite/src/ops/primitive_c.cc index 3b3de99e8c..d4047cd7c2 100644 --- a/mindspore/lite/src/ops/primitive_c.cc +++ b/mindspore/lite/src/ops/primitive_c.cc @@ -101,6 +101,7 @@ #include "src/ops/logical_not.h" #include "src/ops/floor_div.h" #include "src/ops/floor_mod.h" +#include "src/ops/mod.h" #include "src/ops/equal.h" #include "src/ops/not_equal.h" #include "src/ops/less.h" @@ -597,6 +598,10 @@ std::shared_ptr PrimitiveC::Create(const Primitive &prim, const std: return NewPrimitiveC(prim, inputs, quantType); } else if (op_type == "TopK") { return NewPrimitiveC(prim, inputs, quantType); + } else if (op_type == "Mod") { + return NewPrimitiveC(prim, inputs, quantType); + } else if (op_type == "ArgMinWithValue") { + return NewPrimitiveC(prim, inputs, quantType); } else if (op_type == "Range") { return NewPrimitiveC(prim, inputs, quantType); } else if (op_type == "Tile") { @@ -805,6 +810,8 @@ PrimitiveC *PrimitiveC::Create(mindspore::schema::PrimitiveT *primitive) { return new (std::nothrow) FloorDiv(primitive); case schema::PrimitiveType_FloorMod: return new (std::nothrow) FloorMod(primitive); + case schema::PrimitiveType_Mod: + return new (std::nothrow) Mod(primitive); case schema::PrimitiveType_Equal: return new (std::nothrow) Equal(primitive); case schema::PrimitiveType_NotEqual: