diff --git a/mindspore/lite/src/ops/equal.cc b/mindspore/lite/src/ops/equal.cc index 29d91ec01a..ef2ebaeee6 100644 --- a/mindspore/lite/src/ops/equal.cc +++ b/mindspore/lite/src/ops/equal.cc @@ -22,7 +22,35 @@ namespace mindspore { namespace lite { -#ifndef PRIMITIVE_WRITEABLE +#ifdef PRIMITIVE_WRITEABLE +int Equal::UnPackAttr(const Primitive &prim, const std::vector &inputs) { + if (this->primitive_ == nullptr) { + this->primitive_ = new (std::nothrow) schema::PrimitiveT; + if (this->primitive_ == nullptr) { + MS_LOG(ERROR) << "new primitiveT failed"; + return RET_ERROR; + } + this->primitive_->value.type = schema::PrimitiveType_Equal; + } + if (this->primitive_->value.type != schema::PrimitiveType_Equal) { + MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type; + return RET_ERROR; + } + if (this->primitive_->value.value == nullptr) { + auto attr = new (std::nothrow) schema::EqualT(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new primitiveT value failed"; + return RET_ERROR; + } + this->primitive_->value.value = attr; + if (this->primitive_->value.value == nullptr) { + MS_LOG(ERROR) << "primitive value is nullptr"; + return RET_ERROR; + } + } + return RET_OK; +} +#else int Equal::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { MS_ASSERT(nullptr != primitive); MS_ASSERT(nullptr != fbb); diff --git a/mindspore/lite/src/ops/equal.h b/mindspore/lite/src/ops/equal.h index 4eb0efe8ff..1dc8d3ab75 100644 --- a/mindspore/lite/src/ops/equal.h +++ b/mindspore/lite/src/ops/equal.h @@ -31,6 +31,7 @@ class Equal : public ArithmeticCompare { #ifdef PRIMITIVE_WRITEABLE MS_DECLARE_PARENT(Equal, ArithmeticCompare); explicit Equal(schema::PrimitiveT *primitive) : ArithmeticCompare(primitive) {} + int UnPackAttr(const Primitive &prim, const std::vector &inputs) override; #else int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; #endif diff --git a/mindspore/lite/src/ops/gather_nd.cc b/mindspore/lite/src/ops/gather_nd.cc index 976c80d75a..f420e606f8 100644 --- a/mindspore/lite/src/ops/gather_nd.cc +++ b/mindspore/lite/src/ops/gather_nd.cc @@ -23,7 +23,36 @@ namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE - +int GatherNd::UnPackAttr(const Primitive &prim, const std::vector &inputs) { + if (this->primitive_ == nullptr) { + this->primitive_ = new (std::nothrow) schema::PrimitiveT; + if (this->primitive_ == nullptr) { + MS_LOG(ERROR) << "new primitiveT failed"; + return RET_ERROR; + } + this->primitive_->value.type = schema::PrimitiveType_GatherNd; + } + if (this->primitive_->value.type != schema::PrimitiveType_GatherNd) { + MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type; + return RET_ERROR; + } + if (this->primitive_->value.value == nullptr) { + auto attr = new (std::nothrow) schema::GatherNdT(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new primitiveT value failed"; + return RET_ERROR; + } + if (prim.GetAttr("batchDims") != nullptr) { + attr->batchDims = static_cast(GetValue(prim.GetAttr("batchDims"))); + } + this->primitive_->value.value = attr; + if (this->primitive_->value.value == nullptr) { + MS_LOG(ERROR) << "primitive value is nullptr"; + return RET_ERROR; + } + } + return RET_OK; +} #else int GatherNd::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { MS_ASSERT(nullptr != primitive); diff --git a/mindspore/lite/src/ops/gather_nd.h b/mindspore/lite/src/ops/gather_nd.h index f5fe94d0d8..7733050c53 100644 --- a/mindspore/lite/src/ops/gather_nd.h +++ b/mindspore/lite/src/ops/gather_nd.h @@ -32,6 +32,7 @@ class GatherNd : public PrimitiveC { #ifdef PRIMITIVE_WRITEABLE MS_DECLARE_PARENT(GatherNd, PrimitiveC); explicit GatherNd(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} + int UnPackAttr(const Primitive &prim, const std::vector &inputs) override; #else int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; #endif diff --git a/mindspore/lite/src/ops/greater.cc b/mindspore/lite/src/ops/greater.cc index 2950e175a0..a90926b5be 100644 --- a/mindspore/lite/src/ops/greater.cc +++ b/mindspore/lite/src/ops/greater.cc @@ -22,8 +22,35 @@ namespace mindspore { namespace lite { -#ifndef PRIMITIVE_WRITEABLE - +#ifdef PRIMITIVE_WRITEABLE +int Greater::UnPackAttr(const Primitive &prim, const std::vector &inputs) { + if (this->primitive_ == nullptr) { + this->primitive_ = new (std::nothrow) schema::PrimitiveT; + if (this->primitive_ == nullptr) { + MS_LOG(ERROR) << "new primitiveT failed"; + return RET_ERROR; + } + this->primitive_->value.type = schema::PrimitiveType_Greater; + } + if (this->primitive_->value.type != schema::PrimitiveType_Greater) { + MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type; + return RET_ERROR; + } + if (this->primitive_->value.value == nullptr) { + auto attr = new (std::nothrow) schema::GreaterT(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new primitiveT value failed"; + return RET_ERROR; + } + this->primitive_->value.value = attr; + if (this->primitive_->value.value == nullptr) { + MS_LOG(ERROR) << "primitive value is nullptr"; + return RET_ERROR; + } + } + return RET_OK; +} +#else int Greater::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { MS_ASSERT(nullptr != primitive); MS_ASSERT(nullptr != fbb); diff --git a/mindspore/lite/src/ops/greater.h b/mindspore/lite/src/ops/greater.h index c7de708ec2..ae7ef82b51 100644 --- a/mindspore/lite/src/ops/greater.h +++ b/mindspore/lite/src/ops/greater.h @@ -31,6 +31,7 @@ class Greater : public ArithmeticCompare { #ifdef PRIMITIVE_WRITEABLE MS_DECLARE_PARENT(Greater, ArithmeticCompare); explicit Greater(schema::PrimitiveT *primitive) : ArithmeticCompare(primitive) {} + int UnPackAttr(const Primitive &prim, const std::vector &inputs) override; #else int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; #endif diff --git a/mindspore/lite/src/ops/primitive_c.cc b/mindspore/lite/src/ops/primitive_c.cc index 870ef76b78..822ad1d5ab 100644 --- a/mindspore/lite/src/ops/primitive_c.cc +++ b/mindspore/lite/src/ops/primitive_c.cc @@ -593,6 +593,22 @@ std::shared_ptr PrimitiveC::Create(const Primitive &prim, const std: return NewPrimitiveC
(prim, inputs, quantType); } else if (op_type == "Tanh") { return NewPrimitiveC(prim, inputs, quantType); + } else if (op_type == "Equal") { + return NewPrimitiveC(prim, inputs, quantType); + } else if (op_type == "TopK") { + return NewPrimitiveC(prim, inputs, quantType); + } else if (op_type == "Range") { + return NewPrimitiveC(prim, inputs, quantType); + } else if (op_type == "Tile") { + return NewPrimitiveC(prim, inputs, quantType); + } else if (op_type == "GatherNd") { + return NewPrimitiveC(prim, inputs, quantType); + } else if (op_type == "Square") { + return NewPrimitiveC(prim, inputs, quantType); + } else if (op_type == "Sqrt") { + return NewPrimitiveC(prim, inputs, quantType); + } else if (op_type == "Greater") { + return NewPrimitiveC(prim, inputs, quantType); #ifdef SUPPORT_TRAIN } else if (op_type == "SoftmaxCrossEntropyWithLogits") { return NewPrimitiveC(prim, inputs, quantType); @@ -621,8 +637,6 @@ std::shared_ptr PrimitiveC::Create(const Primitive &prim, const std: return NewPrimitiveC(prim, inputs, quantType); } else if (op_type == "FusedBatchNormGrad") { return NewPrimitiveC(prim, inputs, quantType); - } else if (op_type == "Tile") { - return NewPrimitiveC(prim, inputs, quantType); } else if (op_type == "PowerGrad") { return NewPrimitiveC(prim, inputs, quantType); } else if (op_type == "SGD") { diff --git a/mindspore/lite/src/ops/range.cc b/mindspore/lite/src/ops/range.cc index 29f4d4644a..8014d62cd0 100644 --- a/mindspore/lite/src/ops/range.cc +++ b/mindspore/lite/src/ops/range.cc @@ -14,6 +14,7 @@ * limitations under the License. */ +#include #include "src/ops/range.h" #ifndef PRIMITIVE_WRITEABLE @@ -32,7 +33,43 @@ void Range::SetDType(int d_type) { this->primitive_->value.AsRange()->dType = d_ void Range::SetStart(int start) { this->primitive_->value.AsRange()->start = start; } void Range::SetLimit(int limit) { this->primitive_->value.AsRange()->limit = limit; } void Range::SetDelta(int delta) { this->primitive_->value.AsRange()->delta = delta; } - +int Range::UnPackAttr(const Primitive &prim, const std::vector &inputs) { + if (this->primitive_ == nullptr) { + this->primitive_ = new (std::nothrow) schema::PrimitiveT; + if (this->primitive_ == nullptr) { + MS_LOG(ERROR) << "new primitiveT failed"; + return RET_ERROR; + } + this->primitive_->value.type = schema::PrimitiveType_Range; + } + if (this->primitive_->value.type != schema::PrimitiveType_Range) { + MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type; + return RET_ERROR; + } + if (this->primitive_->value.value == nullptr) { + auto attr = new (std::nothrow) schema::RangeT(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new primitiveT value failed"; + return RET_ERROR; + } + this->primitive_->value.value = attr; + attr->dType = 0; + if (prim.GetAttr("start") != nullptr) { + attr->start = static_cast(GetValue(prim.GetAttr("start"))); + } + if (prim.GetAttr("limit") != nullptr) { + attr->limit = static_cast(GetValue(prim.GetAttr("limit"))); + } + if (prim.GetAttr("delta") != nullptr) { + attr->delta = static_cast(GetValue(prim.GetAttr("delta"))); + } + if (this->primitive_->value.value == nullptr) { + MS_LOG(ERROR) << "primitive value is nullptr"; + return RET_ERROR; + } + } + return RET_OK; +} #else int Range::GetDType() const { return this->primitive_->value_as_Range()->dType(); } diff --git a/mindspore/lite/src/ops/range.h b/mindspore/lite/src/ops/range.h index 1bf8c8b882..8f1adafcc6 100644 --- a/mindspore/lite/src/ops/range.h +++ b/mindspore/lite/src/ops/range.h @@ -36,6 +36,7 @@ class Range : public PrimitiveC { void SetStart(int start); void SetLimit(int limit); void SetDelta(int delta); + int UnPackAttr(const Primitive &prim, const std::vector &inputs) override; #else int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; #endif diff --git a/mindspore/lite/src/ops/sqrt.cc b/mindspore/lite/src/ops/sqrt.cc index e35ef381ac..099cad8ec9 100644 --- a/mindspore/lite/src/ops/sqrt.cc +++ b/mindspore/lite/src/ops/sqrt.cc @@ -23,6 +23,33 @@ namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE +int Sqrt::UnPackAttr(const Primitive &prim, const std::vector &inputs) { + if (this->primitive_ == nullptr) { + this->primitive_ = new (std::nothrow) schema::PrimitiveT; + if (this->primitive_ == nullptr) { + MS_LOG(ERROR) << "new primitiveT failed"; + return RET_ERROR; + } + this->primitive_->value.type = schema::PrimitiveType_Sqrt; + } + if (this->primitive_->value.type != schema::PrimitiveType_Sqrt) { + MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type; + return RET_ERROR; + } + if (this->primitive_->value.value == nullptr) { + auto attr = new (std::nothrow) schema::SqrtT(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new primitiveT value failed"; + return RET_ERROR; + } + this->primitive_->value.value = attr; + if (this->primitive_->value.value == nullptr) { + MS_LOG(ERROR) << "primitive value is nullptr"; + return RET_ERROR; + } + } + return RET_OK; +} #else int Sqrt::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { MS_ASSERT(nullptr != primitive); diff --git a/mindspore/lite/src/ops/sqrt.h b/mindspore/lite/src/ops/sqrt.h index 12c7412baf..6f6ca94369 100644 --- a/mindspore/lite/src/ops/sqrt.h +++ b/mindspore/lite/src/ops/sqrt.h @@ -32,6 +32,7 @@ class Sqrt : public ArithmeticSelf { #ifdef PRIMITIVE_WRITEABLE MS_DECLARE_PARENT(Sqrt, ArithmeticSelf); explicit Sqrt(schema::PrimitiveT *primitive) : ArithmeticSelf(primitive) {} + int UnPackAttr(const Primitive &prim, const std::vector &inputs) override; #else int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; #endif diff --git a/mindspore/lite/src/ops/square.cc b/mindspore/lite/src/ops/square.cc index e9f179dfbd..8a126389c1 100644 --- a/mindspore/lite/src/ops/square.cc +++ b/mindspore/lite/src/ops/square.cc @@ -23,6 +23,33 @@ namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE +int Square::UnPackAttr(const Primitive &prim, const std::vector &inputs) { + if (this->primitive_ == nullptr) { + this->primitive_ = new (std::nothrow) schema::PrimitiveT; + if (this->primitive_ == nullptr) { + MS_LOG(ERROR) << "new primitiveT failed"; + return RET_ERROR; + } + this->primitive_->value.type = schema::PrimitiveType_Square; + } + if (this->primitive_->value.type != schema::PrimitiveType_Square) { + MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type; + return RET_ERROR; + } + if (this->primitive_->value.value == nullptr) { + auto attr = new (std::nothrow) schema::SquareT(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new primitiveT value failed"; + return RET_ERROR; + } + this->primitive_->value.value = attr; + if (this->primitive_->value.value == nullptr) { + MS_LOG(ERROR) << "primitive value is nullptr"; + return RET_ERROR; + } + } + return RET_OK; +} #else int Square::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { MS_ASSERT(nullptr != primitive); diff --git a/mindspore/lite/src/ops/square.h b/mindspore/lite/src/ops/square.h index 890087cc63..b86e2bc9bc 100644 --- a/mindspore/lite/src/ops/square.h +++ b/mindspore/lite/src/ops/square.h @@ -31,6 +31,7 @@ class Square : public ArithmeticSelf { #ifdef PRIMITIVE_WRITEABLE MS_DECLARE_PARENT(Square, ArithmeticSelf); explicit Square(schema::PrimitiveT *primitive) : ArithmeticSelf(primitive) {} + int UnPackAttr(const Primitive &prim, const std::vector &inputs) override; #else int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; #endif diff --git a/mindspore/lite/src/ops/tile.cc b/mindspore/lite/src/ops/tile.cc index 319f0489be..90e86752eb 100644 --- a/mindspore/lite/src/ops/tile.cc +++ b/mindspore/lite/src/ops/tile.cc @@ -52,12 +52,6 @@ int Tile::UnPackAttr(const Primitive &prim, const std::vector &input MS_LOG(ERROR) << "new primitiveT value failed"; return RET_ERROR; } - if (prim.GetAttr("dims") == nullptr) { - MS_LOG(INFO) << "Tile's attr dims is set to default"; - attr->dims = {1}; - } else { - attr->dims = CastToInt(prim.GetAttr("dims")); - } if (inputs.size() == kAnfPopulaterInputNumTwo) { auto inputNode = inputs[kAnfPopulaterInputNumOne]; MS_ASSERT(inputNode != nullptr); @@ -80,6 +74,15 @@ int Tile::UnPackAttr(const Primitive &prim, const std::vector &input } } } + if (prim.GetAttr("dims") == nullptr) { + MS_LOG(INFO) << "Tile's attr dims is set to default. The operator in mindspore has no attribute" + "named dims and all the dimensions needs to be multiplied by default."; + for (size_t i = 0; i < attr->multiples.size(); i++) { + attr->dims.push_back(i); + } + } else { + attr->dims = CastToInt(prim.GetAttr("dims")); + } this->primitive_->value.value = attr; } return RET_OK; diff --git a/mindspore/lite/src/ops/topk.cc b/mindspore/lite/src/ops/topk.cc index 24eab142c3..55294d0d27 100644 --- a/mindspore/lite/src/ops/topk.cc +++ b/mindspore/lite/src/ops/topk.cc @@ -28,7 +28,38 @@ bool TopK::GetSorted() const { return this->primitive_->value.AsTopK()->sorted; void TopK::SetK(int k) { this->primitive_->value.AsTopK()->k = k; } void TopK::SetSorted(bool sorted) { this->primitive_->value.AsTopK()->sorted = sorted; } - +int TopK::UnPackAttr(const Primitive &prim, const std::vector &inputs) { + if (this->primitive_ == nullptr) { + this->primitive_ = new (std::nothrow) schema::PrimitiveT; + if (this->primitive_ == nullptr) { + MS_LOG(ERROR) << "new primitiveT failed"; + return RET_ERROR; + } + this->primitive_->value.type = schema::PrimitiveType_TopK; + } + if (this->primitive_->value.type != schema::PrimitiveType_TopK) { + MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type; + return RET_ERROR; + } + if (this->primitive_->value.value == nullptr) { + auto attr = new (std::nothrow) schema::TopKT(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new primitiveT value failed"; + return RET_ERROR; + } + this->primitive_->value.value = attr; + // the k value of mindspore models is one of inputs instead of an attribute. + attr->k = 0; + if (prim.GetAttr("sorted") != nullptr) { + attr->sorted = GetValue(prim.GetAttr("sorted")); + } + if (this->primitive_->value.value == nullptr) { + MS_LOG(ERROR) << "primitive value is nullptr"; + return RET_ERROR; + } + } + return RET_OK; +} #else int TopK::GetK() const { return this->primitive_->value_as_TopK()->k(); } @@ -60,7 +91,7 @@ int TopK::InferShape(std::vector inputs_, std::vector output } auto input = inputs_.front(); MS_ASSERT(input != nullptr); - if (input->format() != schema::Format::Format_NHWC) { + if (input->shape().size() == kDimension_4d && input->format() != schema::Format::Format_NHWC) { MS_LOG(ERROR) << "topk only support NHWC now!"; return RET_FORMAT_ERR; } @@ -76,7 +107,16 @@ int TopK::InferShape(std::vector inputs_, std::vector output return RET_INFER_INVALID; } auto out_shape = input->shape(); - out_shape.at(out_shape.size() - 1) = GetK(); + if (inputs_.size() == kSingleNum) { + out_shape.at(out_shape.size() - 1) = GetK(); + } else if (inputs_.size() == kDoubleNum) { + if (inputs_.at(1)->data_c() == nullptr) { + return RET_INFER_INVALID; + } else { + int *data = reinterpret_cast(inputs_.at(1)->data_c()); + out_shape.at(out_shape.size() - 1) = *data; + } + } if (inputs_.size() == kDoubleNum && inputs_.at(1)->data_c() != nullptr) { out_shape.at(out_shape.size() - 1) = reinterpret_cast(inputs_.at(1)->data_c())[0]; } diff --git a/mindspore/lite/src/ops/topk.h b/mindspore/lite/src/ops/topk.h index 2a1cb57ce7..6364002c2e 100644 --- a/mindspore/lite/src/ops/topk.h +++ b/mindspore/lite/src/ops/topk.h @@ -34,6 +34,7 @@ class TopK : public PrimitiveC { explicit TopK(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} void SetK(int k); void SetSorted(bool sorted); + int UnPackAttr(const Primitive &prim, const std::vector &inputs) override; #else int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; #endif diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/gatherNd_fp32.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/gatherNd_fp32.cc index 1759a7817a..6b318a2911 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/gatherNd_fp32.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/gatherNd_fp32.cc @@ -152,4 +152,5 @@ kernel::LiteKernel *CpuGatherNdFp32KernelCreator(const std::vector