From: @wang_shaocong Reviewed-by: @hangangqiang,@zhanghaibo5 Signed-off-by: @hangangqiangtags/v1.1.0
| @@ -22,7 +22,35 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| #ifndef PRIMITIVE_WRITEABLE | |||
| #ifdef PRIMITIVE_WRITEABLE | |||
| int Equal::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) { | |||
| if (this->primitive_ == nullptr) { | |||
| this->primitive_ = new (std::nothrow) schema::PrimitiveT; | |||
| if (this->primitive_ == nullptr) { | |||
| MS_LOG(ERROR) << "new primitiveT failed"; | |||
| return RET_ERROR; | |||
| } | |||
| this->primitive_->value.type = schema::PrimitiveType_Equal; | |||
| } | |||
| if (this->primitive_->value.type != schema::PrimitiveType_Equal) { | |||
| MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type; | |||
| return RET_ERROR; | |||
| } | |||
| if (this->primitive_->value.value == nullptr) { | |||
| auto attr = new (std::nothrow) schema::EqualT(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new primitiveT value failed"; | |||
| return RET_ERROR; | |||
| } | |||
| this->primitive_->value.value = attr; | |||
| if (this->primitive_->value.value == nullptr) { | |||
| MS_LOG(ERROR) << "primitive value is nullptr"; | |||
| return RET_ERROR; | |||
| } | |||
| } | |||
| return RET_OK; | |||
| } | |||
| #else | |||
| int Equal::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { | |||
| MS_ASSERT(nullptr != primitive); | |||
| MS_ASSERT(nullptr != fbb); | |||
| @@ -31,6 +31,7 @@ class Equal : public ArithmeticCompare { | |||
| #ifdef PRIMITIVE_WRITEABLE | |||
| MS_DECLARE_PARENT(Equal, ArithmeticCompare); | |||
| explicit Equal(schema::PrimitiveT *primitive) : ArithmeticCompare(primitive) {} | |||
| int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override; | |||
| #else | |||
| int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; | |||
| #endif | |||
| @@ -23,7 +23,36 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| #ifdef PRIMITIVE_WRITEABLE | |||
| int GatherNd::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) { | |||
| if (this->primitive_ == nullptr) { | |||
| this->primitive_ = new (std::nothrow) schema::PrimitiveT; | |||
| if (this->primitive_ == nullptr) { | |||
| MS_LOG(ERROR) << "new primitiveT failed"; | |||
| return RET_ERROR; | |||
| } | |||
| this->primitive_->value.type = schema::PrimitiveType_GatherNd; | |||
| } | |||
| if (this->primitive_->value.type != schema::PrimitiveType_GatherNd) { | |||
| MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type; | |||
| return RET_ERROR; | |||
| } | |||
| if (this->primitive_->value.value == nullptr) { | |||
| auto attr = new (std::nothrow) schema::GatherNdT(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new primitiveT value failed"; | |||
| return RET_ERROR; | |||
| } | |||
| if (prim.GetAttr("batchDims") != nullptr) { | |||
| attr->batchDims = static_cast<int32_t>(GetValue<int64_t>(prim.GetAttr("batchDims"))); | |||
| } | |||
| this->primitive_->value.value = attr; | |||
| if (this->primitive_->value.value == nullptr) { | |||
| MS_LOG(ERROR) << "primitive value is nullptr"; | |||
| return RET_ERROR; | |||
| } | |||
| } | |||
| return RET_OK; | |||
| } | |||
| #else | |||
| int GatherNd::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { | |||
| MS_ASSERT(nullptr != primitive); | |||
| @@ -32,6 +32,7 @@ class GatherNd : public PrimitiveC { | |||
| #ifdef PRIMITIVE_WRITEABLE | |||
| MS_DECLARE_PARENT(GatherNd, PrimitiveC); | |||
| explicit GatherNd(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} | |||
| int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override; | |||
| #else | |||
| int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; | |||
| #endif | |||
| @@ -22,8 +22,35 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| #ifndef PRIMITIVE_WRITEABLE | |||
| #ifdef PRIMITIVE_WRITEABLE | |||
| int Greater::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) { | |||
| if (this->primitive_ == nullptr) { | |||
| this->primitive_ = new (std::nothrow) schema::PrimitiveT; | |||
| if (this->primitive_ == nullptr) { | |||
| MS_LOG(ERROR) << "new primitiveT failed"; | |||
| return RET_ERROR; | |||
| } | |||
| this->primitive_->value.type = schema::PrimitiveType_Greater; | |||
| } | |||
| if (this->primitive_->value.type != schema::PrimitiveType_Greater) { | |||
| MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type; | |||
| return RET_ERROR; | |||
| } | |||
| if (this->primitive_->value.value == nullptr) { | |||
| auto attr = new (std::nothrow) schema::GreaterT(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new primitiveT value failed"; | |||
| return RET_ERROR; | |||
| } | |||
| this->primitive_->value.value = attr; | |||
| if (this->primitive_->value.value == nullptr) { | |||
| MS_LOG(ERROR) << "primitive value is nullptr"; | |||
| return RET_ERROR; | |||
| } | |||
| } | |||
| return RET_OK; | |||
| } | |||
| #else | |||
| int Greater::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { | |||
| MS_ASSERT(nullptr != primitive); | |||
| MS_ASSERT(nullptr != fbb); | |||
| @@ -31,6 +31,7 @@ class Greater : public ArithmeticCompare { | |||
| #ifdef PRIMITIVE_WRITEABLE | |||
| MS_DECLARE_PARENT(Greater, ArithmeticCompare); | |||
| explicit Greater(schema::PrimitiveT *primitive) : ArithmeticCompare(primitive) {} | |||
| int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override; | |||
| #else | |||
| int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; | |||
| #endif | |||
| @@ -593,6 +593,22 @@ std::shared_ptr<PrimitiveC> PrimitiveC::Create(const Primitive &prim, const std: | |||
| return NewPrimitiveC<Div>(prim, inputs, quantType); | |||
| } else if (op_type == "Tanh") { | |||
| return NewPrimitiveC<Activation>(prim, inputs, quantType); | |||
| } else if (op_type == "Equal") { | |||
| return NewPrimitiveC<Equal>(prim, inputs, quantType); | |||
| } else if (op_type == "TopK") { | |||
| return NewPrimitiveC<TopK>(prim, inputs, quantType); | |||
| } else if (op_type == "Range") { | |||
| return NewPrimitiveC<Range>(prim, inputs, quantType); | |||
| } else if (op_type == "Tile") { | |||
| return NewPrimitiveC<Tile>(prim, inputs, quantType); | |||
| } else if (op_type == "GatherNd") { | |||
| return NewPrimitiveC<GatherNd>(prim, inputs, quantType); | |||
| } else if (op_type == "Square") { | |||
| return NewPrimitiveC<Square>(prim, inputs, quantType); | |||
| } else if (op_type == "Sqrt") { | |||
| return NewPrimitiveC<Sqrt>(prim, inputs, quantType); | |||
| } else if (op_type == "Greater") { | |||
| return NewPrimitiveC<Greater>(prim, inputs, quantType); | |||
| #ifdef SUPPORT_TRAIN | |||
| } else if (op_type == "SoftmaxCrossEntropyWithLogits") { | |||
| return NewPrimitiveC<SoftmaxCrossEntropy>(prim, inputs, quantType); | |||
| @@ -621,8 +637,6 @@ std::shared_ptr<PrimitiveC> PrimitiveC::Create(const Primitive &prim, const std: | |||
| return NewPrimitiveC<FlattenGrad>(prim, inputs, quantType); | |||
| } else if (op_type == "FusedBatchNormGrad") { | |||
| return NewPrimitiveC<BNGrad>(prim, inputs, quantType); | |||
| } else if (op_type == "Tile") { | |||
| return NewPrimitiveC<Tile>(prim, inputs, quantType); | |||
| } else if (op_type == "PowerGrad") { | |||
| return NewPrimitiveC<PowerGrad>(prim, inputs, quantType); | |||
| } else if (op_type == "SGD") { | |||
| @@ -14,6 +14,7 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #include <algorithm> | |||
| #include "src/ops/range.h" | |||
| #ifndef PRIMITIVE_WRITEABLE | |||
| @@ -32,7 +33,43 @@ void Range::SetDType(int d_type) { this->primitive_->value.AsRange()->dType = d_ | |||
| void Range::SetStart(int start) { this->primitive_->value.AsRange()->start = start; } | |||
| void Range::SetLimit(int limit) { this->primitive_->value.AsRange()->limit = limit; } | |||
| void Range::SetDelta(int delta) { this->primitive_->value.AsRange()->delta = delta; } | |||
| int Range::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) { | |||
| if (this->primitive_ == nullptr) { | |||
| this->primitive_ = new (std::nothrow) schema::PrimitiveT; | |||
| if (this->primitive_ == nullptr) { | |||
| MS_LOG(ERROR) << "new primitiveT failed"; | |||
| return RET_ERROR; | |||
| } | |||
| this->primitive_->value.type = schema::PrimitiveType_Range; | |||
| } | |||
| if (this->primitive_->value.type != schema::PrimitiveType_Range) { | |||
| MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type; | |||
| return RET_ERROR; | |||
| } | |||
| if (this->primitive_->value.value == nullptr) { | |||
| auto attr = new (std::nothrow) schema::RangeT(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new primitiveT value failed"; | |||
| return RET_ERROR; | |||
| } | |||
| this->primitive_->value.value = attr; | |||
| attr->dType = 0; | |||
| if (prim.GetAttr("start") != nullptr) { | |||
| attr->start = static_cast<int32_t>(GetValue<float>(prim.GetAttr("start"))); | |||
| } | |||
| if (prim.GetAttr("limit") != nullptr) { | |||
| attr->limit = static_cast<int32_t>(GetValue<float>(prim.GetAttr("limit"))); | |||
| } | |||
| if (prim.GetAttr("delta") != nullptr) { | |||
| attr->delta = static_cast<int32_t>(GetValue<float>(prim.GetAttr("delta"))); | |||
| } | |||
| if (this->primitive_->value.value == nullptr) { | |||
| MS_LOG(ERROR) << "primitive value is nullptr"; | |||
| return RET_ERROR; | |||
| } | |||
| } | |||
| return RET_OK; | |||
| } | |||
| #else | |||
| int Range::GetDType() const { return this->primitive_->value_as_Range()->dType(); } | |||
| @@ -36,6 +36,7 @@ class Range : public PrimitiveC { | |||
| void SetStart(int start); | |||
| void SetLimit(int limit); | |||
| void SetDelta(int delta); | |||
| int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override; | |||
| #else | |||
| int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; | |||
| #endif | |||
| @@ -23,6 +23,33 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| #ifdef PRIMITIVE_WRITEABLE | |||
| int Sqrt::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) { | |||
| if (this->primitive_ == nullptr) { | |||
| this->primitive_ = new (std::nothrow) schema::PrimitiveT; | |||
| if (this->primitive_ == nullptr) { | |||
| MS_LOG(ERROR) << "new primitiveT failed"; | |||
| return RET_ERROR; | |||
| } | |||
| this->primitive_->value.type = schema::PrimitiveType_Sqrt; | |||
| } | |||
| if (this->primitive_->value.type != schema::PrimitiveType_Sqrt) { | |||
| MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type; | |||
| return RET_ERROR; | |||
| } | |||
| if (this->primitive_->value.value == nullptr) { | |||
| auto attr = new (std::nothrow) schema::SqrtT(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new primitiveT value failed"; | |||
| return RET_ERROR; | |||
| } | |||
| this->primitive_->value.value = attr; | |||
| if (this->primitive_->value.value == nullptr) { | |||
| MS_LOG(ERROR) << "primitive value is nullptr"; | |||
| return RET_ERROR; | |||
| } | |||
| } | |||
| return RET_OK; | |||
| } | |||
| #else | |||
| int Sqrt::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { | |||
| MS_ASSERT(nullptr != primitive); | |||
| @@ -32,6 +32,7 @@ class Sqrt : public ArithmeticSelf { | |||
| #ifdef PRIMITIVE_WRITEABLE | |||
| MS_DECLARE_PARENT(Sqrt, ArithmeticSelf); | |||
| explicit Sqrt(schema::PrimitiveT *primitive) : ArithmeticSelf(primitive) {} | |||
| int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override; | |||
| #else | |||
| int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; | |||
| #endif | |||
| @@ -23,6 +23,33 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| #ifdef PRIMITIVE_WRITEABLE | |||
| int Square::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) { | |||
| if (this->primitive_ == nullptr) { | |||
| this->primitive_ = new (std::nothrow) schema::PrimitiveT; | |||
| if (this->primitive_ == nullptr) { | |||
| MS_LOG(ERROR) << "new primitiveT failed"; | |||
| return RET_ERROR; | |||
| } | |||
| this->primitive_->value.type = schema::PrimitiveType_Square; | |||
| } | |||
| if (this->primitive_->value.type != schema::PrimitiveType_Square) { | |||
| MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type; | |||
| return RET_ERROR; | |||
| } | |||
| if (this->primitive_->value.value == nullptr) { | |||
| auto attr = new (std::nothrow) schema::SquareT(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new primitiveT value failed"; | |||
| return RET_ERROR; | |||
| } | |||
| this->primitive_->value.value = attr; | |||
| if (this->primitive_->value.value == nullptr) { | |||
| MS_LOG(ERROR) << "primitive value is nullptr"; | |||
| return RET_ERROR; | |||
| } | |||
| } | |||
| return RET_OK; | |||
| } | |||
| #else | |||
| int Square::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { | |||
| MS_ASSERT(nullptr != primitive); | |||
| @@ -31,6 +31,7 @@ class Square : public ArithmeticSelf { | |||
| #ifdef PRIMITIVE_WRITEABLE | |||
| MS_DECLARE_PARENT(Square, ArithmeticSelf); | |||
| explicit Square(schema::PrimitiveT *primitive) : ArithmeticSelf(primitive) {} | |||
| int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override; | |||
| #else | |||
| int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; | |||
| #endif | |||
| @@ -52,12 +52,6 @@ int Tile::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &input | |||
| MS_LOG(ERROR) << "new primitiveT value failed"; | |||
| return RET_ERROR; | |||
| } | |||
| if (prim.GetAttr("dims") == nullptr) { | |||
| MS_LOG(INFO) << "Tile's attr dims is set to default"; | |||
| attr->dims = {1}; | |||
| } else { | |||
| attr->dims = CastToInt(prim.GetAttr("dims")); | |||
| } | |||
| if (inputs.size() == kAnfPopulaterInputNumTwo) { | |||
| auto inputNode = inputs[kAnfPopulaterInputNumOne]; | |||
| MS_ASSERT(inputNode != nullptr); | |||
| @@ -80,6 +74,15 @@ int Tile::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &input | |||
| } | |||
| } | |||
| } | |||
| if (prim.GetAttr("dims") == nullptr) { | |||
| MS_LOG(INFO) << "Tile's attr dims is set to default. The operator in mindspore has no attribute" | |||
| "named dims and all the dimensions needs to be multiplied by default."; | |||
| for (size_t i = 0; i < attr->multiples.size(); i++) { | |||
| attr->dims.push_back(i); | |||
| } | |||
| } else { | |||
| attr->dims = CastToInt(prim.GetAttr("dims")); | |||
| } | |||
| this->primitive_->value.value = attr; | |||
| } | |||
| return RET_OK; | |||
| @@ -28,7 +28,38 @@ bool TopK::GetSorted() const { return this->primitive_->value.AsTopK()->sorted; | |||
| void TopK::SetK(int k) { this->primitive_->value.AsTopK()->k = k; } | |||
| void TopK::SetSorted(bool sorted) { this->primitive_->value.AsTopK()->sorted = sorted; } | |||
| int TopK::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) { | |||
| if (this->primitive_ == nullptr) { | |||
| this->primitive_ = new (std::nothrow) schema::PrimitiveT; | |||
| if (this->primitive_ == nullptr) { | |||
| MS_LOG(ERROR) << "new primitiveT failed"; | |||
| return RET_ERROR; | |||
| } | |||
| this->primitive_->value.type = schema::PrimitiveType_TopK; | |||
| } | |||
| if (this->primitive_->value.type != schema::PrimitiveType_TopK) { | |||
| MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type; | |||
| return RET_ERROR; | |||
| } | |||
| if (this->primitive_->value.value == nullptr) { | |||
| auto attr = new (std::nothrow) schema::TopKT(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new primitiveT value failed"; | |||
| return RET_ERROR; | |||
| } | |||
| this->primitive_->value.value = attr; | |||
| // the k value of mindspore models is one of inputs instead of an attribute. | |||
| attr->k = 0; | |||
| if (prim.GetAttr("sorted") != nullptr) { | |||
| attr->sorted = GetValue<bool>(prim.GetAttr("sorted")); | |||
| } | |||
| if (this->primitive_->value.value == nullptr) { | |||
| MS_LOG(ERROR) << "primitive value is nullptr"; | |||
| return RET_ERROR; | |||
| } | |||
| } | |||
| return RET_OK; | |||
| } | |||
| #else | |||
| int TopK::GetK() const { return this->primitive_->value_as_TopK()->k(); } | |||
| @@ -60,7 +91,7 @@ int TopK::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> output | |||
| } | |||
| auto input = inputs_.front(); | |||
| MS_ASSERT(input != nullptr); | |||
| if (input->format() != schema::Format::Format_NHWC) { | |||
| if (input->shape().size() == kDimension_4d && input->format() != schema::Format::Format_NHWC) { | |||
| MS_LOG(ERROR) << "topk only support NHWC now!"; | |||
| return RET_FORMAT_ERR; | |||
| } | |||
| @@ -76,7 +107,16 @@ int TopK::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> output | |||
| return RET_INFER_INVALID; | |||
| } | |||
| auto out_shape = input->shape(); | |||
| out_shape.at(out_shape.size() - 1) = GetK(); | |||
| if (inputs_.size() == kSingleNum) { | |||
| out_shape.at(out_shape.size() - 1) = GetK(); | |||
| } else if (inputs_.size() == kDoubleNum) { | |||
| if (inputs_.at(1)->data_c() == nullptr) { | |||
| return RET_INFER_INVALID; | |||
| } else { | |||
| int *data = reinterpret_cast<int32_t *>(inputs_.at(1)->data_c()); | |||
| out_shape.at(out_shape.size() - 1) = *data; | |||
| } | |||
| } | |||
| if (inputs_.size() == kDoubleNum && inputs_.at(1)->data_c() != nullptr) { | |||
| out_shape.at(out_shape.size() - 1) = reinterpret_cast<int *>(inputs_.at(1)->data_c())[0]; | |||
| } | |||
| @@ -34,6 +34,7 @@ class TopK : public PrimitiveC { | |||
| explicit TopK(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} | |||
| void SetK(int k); | |||
| void SetSorted(bool sorted); | |||
| int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override; | |||
| #else | |||
| int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; | |||
| #endif | |||
| @@ -152,4 +152,5 @@ kernel::LiteKernel *CpuGatherNdFp32KernelCreator(const std::vector<lite::Tensor | |||
| } | |||
| REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_GatherNd, CpuGatherNdFp32KernelCreator) | |||
| REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_GatherNd, CpuGatherNdFp32KernelCreator) | |||
| } // namespace mindspore::kernel | |||