Browse Source

!10013 [MSLITE] Fix bug of converter for mindspore models.

From: @wang_shaocong
Reviewed-by: @hangangqiang,@zhanghaibo5
Signed-off-by: @hangangqiang
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
9ed5168c90
17 changed files with 256 additions and 16 deletions
  1. +29
    -1
      mindspore/lite/src/ops/equal.cc
  2. +1
    -0
      mindspore/lite/src/ops/equal.h
  3. +30
    -1
      mindspore/lite/src/ops/gather_nd.cc
  4. +1
    -0
      mindspore/lite/src/ops/gather_nd.h
  5. +29
    -2
      mindspore/lite/src/ops/greater.cc
  6. +1
    -0
      mindspore/lite/src/ops/greater.h
  7. +16
    -2
      mindspore/lite/src/ops/primitive_c.cc
  8. +38
    -1
      mindspore/lite/src/ops/range.cc
  9. +1
    -0
      mindspore/lite/src/ops/range.h
  10. +27
    -0
      mindspore/lite/src/ops/sqrt.cc
  11. +1
    -0
      mindspore/lite/src/ops/sqrt.h
  12. +27
    -0
      mindspore/lite/src/ops/square.cc
  13. +1
    -0
      mindspore/lite/src/ops/square.h
  14. +9
    -6
      mindspore/lite/src/ops/tile.cc
  15. +43
    -3
      mindspore/lite/src/ops/topk.cc
  16. +1
    -0
      mindspore/lite/src/ops/topk.h
  17. +1
    -0
      mindspore/lite/src/runtime/kernel/arm/fp32/gatherNd_fp32.cc

+ 29
- 1
mindspore/lite/src/ops/equal.cc View File

@@ -22,7 +22,35 @@

namespace mindspore {
namespace lite {
#ifndef PRIMITIVE_WRITEABLE
#ifdef PRIMITIVE_WRITEABLE
int Equal::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
if (this->primitive_ == nullptr) {
this->primitive_ = new (std::nothrow) schema::PrimitiveT;
if (this->primitive_ == nullptr) {
MS_LOG(ERROR) << "new primitiveT failed";
return RET_ERROR;
}
this->primitive_->value.type = schema::PrimitiveType_Equal;
}
if (this->primitive_->value.type != schema::PrimitiveType_Equal) {
MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type;
return RET_ERROR;
}
if (this->primitive_->value.value == nullptr) {
auto attr = new (std::nothrow) schema::EqualT();
if (attr == nullptr) {
MS_LOG(ERROR) << "new primitiveT value failed";
return RET_ERROR;
}
this->primitive_->value.value = attr;
if (this->primitive_->value.value == nullptr) {
MS_LOG(ERROR) << "primitive value is nullptr";
return RET_ERROR;
}
}
return RET_OK;
}
#else
int Equal::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
MS_ASSERT(nullptr != primitive);
MS_ASSERT(nullptr != fbb);


+ 1
- 0
mindspore/lite/src/ops/equal.h View File

@@ -31,6 +31,7 @@ class Equal : public ArithmeticCompare {
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(Equal, ArithmeticCompare);
explicit Equal(schema::PrimitiveT *primitive) : ArithmeticCompare(primitive) {}
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
#else
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif


+ 30
- 1
mindspore/lite/src/ops/gather_nd.cc View File

@@ -23,7 +23,36 @@
namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE

int GatherNd::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
if (this->primitive_ == nullptr) {
this->primitive_ = new (std::nothrow) schema::PrimitiveT;
if (this->primitive_ == nullptr) {
MS_LOG(ERROR) << "new primitiveT failed";
return RET_ERROR;
}
this->primitive_->value.type = schema::PrimitiveType_GatherNd;
}
if (this->primitive_->value.type != schema::PrimitiveType_GatherNd) {
MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type;
return RET_ERROR;
}
if (this->primitive_->value.value == nullptr) {
auto attr = new (std::nothrow) schema::GatherNdT();
if (attr == nullptr) {
MS_LOG(ERROR) << "new primitiveT value failed";
return RET_ERROR;
}
if (prim.GetAttr("batchDims") != nullptr) {
attr->batchDims = static_cast<int32_t>(GetValue<int64_t>(prim.GetAttr("batchDims")));
}
this->primitive_->value.value = attr;
if (this->primitive_->value.value == nullptr) {
MS_LOG(ERROR) << "primitive value is nullptr";
return RET_ERROR;
}
}
return RET_OK;
}
#else
int GatherNd::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
MS_ASSERT(nullptr != primitive);


+ 1
- 0
mindspore/lite/src/ops/gather_nd.h View File

@@ -32,6 +32,7 @@ class GatherNd : public PrimitiveC {
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(GatherNd, PrimitiveC);
explicit GatherNd(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
#else
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif


+ 29
- 2
mindspore/lite/src/ops/greater.cc View File

@@ -22,8 +22,35 @@

namespace mindspore {
namespace lite {
#ifndef PRIMITIVE_WRITEABLE

#ifdef PRIMITIVE_WRITEABLE
int Greater::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
if (this->primitive_ == nullptr) {
this->primitive_ = new (std::nothrow) schema::PrimitiveT;
if (this->primitive_ == nullptr) {
MS_LOG(ERROR) << "new primitiveT failed";
return RET_ERROR;
}
this->primitive_->value.type = schema::PrimitiveType_Greater;
}
if (this->primitive_->value.type != schema::PrimitiveType_Greater) {
MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type;
return RET_ERROR;
}
if (this->primitive_->value.value == nullptr) {
auto attr = new (std::nothrow) schema::GreaterT();
if (attr == nullptr) {
MS_LOG(ERROR) << "new primitiveT value failed";
return RET_ERROR;
}
this->primitive_->value.value = attr;
if (this->primitive_->value.value == nullptr) {
MS_LOG(ERROR) << "primitive value is nullptr";
return RET_ERROR;
}
}
return RET_OK;
}
#else
int Greater::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
MS_ASSERT(nullptr != primitive);
MS_ASSERT(nullptr != fbb);


+ 1
- 0
mindspore/lite/src/ops/greater.h View File

@@ -31,6 +31,7 @@ class Greater : public ArithmeticCompare {
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(Greater, ArithmeticCompare);
explicit Greater(schema::PrimitiveT *primitive) : ArithmeticCompare(primitive) {}
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
#else
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif


+ 16
- 2
mindspore/lite/src/ops/primitive_c.cc View File

@@ -593,6 +593,22 @@ std::shared_ptr<PrimitiveC> PrimitiveC::Create(const Primitive &prim, const std:
return NewPrimitiveC<Div>(prim, inputs, quantType);
} else if (op_type == "Tanh") {
return NewPrimitiveC<Activation>(prim, inputs, quantType);
} else if (op_type == "Equal") {
return NewPrimitiveC<Equal>(prim, inputs, quantType);
} else if (op_type == "TopK") {
return NewPrimitiveC<TopK>(prim, inputs, quantType);
} else if (op_type == "Range") {
return NewPrimitiveC<Range>(prim, inputs, quantType);
} else if (op_type == "Tile") {
return NewPrimitiveC<Tile>(prim, inputs, quantType);
} else if (op_type == "GatherNd") {
return NewPrimitiveC<GatherNd>(prim, inputs, quantType);
} else if (op_type == "Square") {
return NewPrimitiveC<Square>(prim, inputs, quantType);
} else if (op_type == "Sqrt") {
return NewPrimitiveC<Sqrt>(prim, inputs, quantType);
} else if (op_type == "Greater") {
return NewPrimitiveC<Greater>(prim, inputs, quantType);
#ifdef SUPPORT_TRAIN
} else if (op_type == "SoftmaxCrossEntropyWithLogits") {
return NewPrimitiveC<SoftmaxCrossEntropy>(prim, inputs, quantType);
@@ -621,8 +637,6 @@ std::shared_ptr<PrimitiveC> PrimitiveC::Create(const Primitive &prim, const std:
return NewPrimitiveC<FlattenGrad>(prim, inputs, quantType);
} else if (op_type == "FusedBatchNormGrad") {
return NewPrimitiveC<BNGrad>(prim, inputs, quantType);
} else if (op_type == "Tile") {
return NewPrimitiveC<Tile>(prim, inputs, quantType);
} else if (op_type == "PowerGrad") {
return NewPrimitiveC<PowerGrad>(prim, inputs, quantType);
} else if (op_type == "SGD") {


+ 38
- 1
mindspore/lite/src/ops/range.cc View File

@@ -14,6 +14,7 @@
* limitations under the License.
*/

#include <algorithm>
#include "src/ops/range.h"

#ifndef PRIMITIVE_WRITEABLE
@@ -32,7 +33,43 @@ void Range::SetDType(int d_type) { this->primitive_->value.AsRange()->dType = d_
void Range::SetStart(int start) { this->primitive_->value.AsRange()->start = start; }
void Range::SetLimit(int limit) { this->primitive_->value.AsRange()->limit = limit; }
void Range::SetDelta(int delta) { this->primitive_->value.AsRange()->delta = delta; }

int Range::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
if (this->primitive_ == nullptr) {
this->primitive_ = new (std::nothrow) schema::PrimitiveT;
if (this->primitive_ == nullptr) {
MS_LOG(ERROR) << "new primitiveT failed";
return RET_ERROR;
}
this->primitive_->value.type = schema::PrimitiveType_Range;
}
if (this->primitive_->value.type != schema::PrimitiveType_Range) {
MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type;
return RET_ERROR;
}
if (this->primitive_->value.value == nullptr) {
auto attr = new (std::nothrow) schema::RangeT();
if (attr == nullptr) {
MS_LOG(ERROR) << "new primitiveT value failed";
return RET_ERROR;
}
this->primitive_->value.value = attr;
attr->dType = 0;
if (prim.GetAttr("start") != nullptr) {
attr->start = static_cast<int32_t>(GetValue<float>(prim.GetAttr("start")));
}
if (prim.GetAttr("limit") != nullptr) {
attr->limit = static_cast<int32_t>(GetValue<float>(prim.GetAttr("limit")));
}
if (prim.GetAttr("delta") != nullptr) {
attr->delta = static_cast<int32_t>(GetValue<float>(prim.GetAttr("delta")));
}
if (this->primitive_->value.value == nullptr) {
MS_LOG(ERROR) << "primitive value is nullptr";
return RET_ERROR;
}
}
return RET_OK;
}
#else

int Range::GetDType() const { return this->primitive_->value_as_Range()->dType(); }


+ 1
- 0
mindspore/lite/src/ops/range.h View File

@@ -36,6 +36,7 @@ class Range : public PrimitiveC {
void SetStart(int start);
void SetLimit(int limit);
void SetDelta(int delta);
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
#else
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif


+ 27
- 0
mindspore/lite/src/ops/sqrt.cc View File

@@ -23,6 +23,33 @@
namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
int Sqrt::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
if (this->primitive_ == nullptr) {
this->primitive_ = new (std::nothrow) schema::PrimitiveT;
if (this->primitive_ == nullptr) {
MS_LOG(ERROR) << "new primitiveT failed";
return RET_ERROR;
}
this->primitive_->value.type = schema::PrimitiveType_Sqrt;
}
if (this->primitive_->value.type != schema::PrimitiveType_Sqrt) {
MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type;
return RET_ERROR;
}
if (this->primitive_->value.value == nullptr) {
auto attr = new (std::nothrow) schema::SqrtT();
if (attr == nullptr) {
MS_LOG(ERROR) << "new primitiveT value failed";
return RET_ERROR;
}
this->primitive_->value.value = attr;
if (this->primitive_->value.value == nullptr) {
MS_LOG(ERROR) << "primitive value is nullptr";
return RET_ERROR;
}
}
return RET_OK;
}
#else
int Sqrt::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
MS_ASSERT(nullptr != primitive);


+ 1
- 0
mindspore/lite/src/ops/sqrt.h View File

@@ -32,6 +32,7 @@ class Sqrt : public ArithmeticSelf {
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(Sqrt, ArithmeticSelf);
explicit Sqrt(schema::PrimitiveT *primitive) : ArithmeticSelf(primitive) {}
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
#else
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif


+ 27
- 0
mindspore/lite/src/ops/square.cc View File

@@ -23,6 +23,33 @@
namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
int Square::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
if (this->primitive_ == nullptr) {
this->primitive_ = new (std::nothrow) schema::PrimitiveT;
if (this->primitive_ == nullptr) {
MS_LOG(ERROR) << "new primitiveT failed";
return RET_ERROR;
}
this->primitive_->value.type = schema::PrimitiveType_Square;
}
if (this->primitive_->value.type != schema::PrimitiveType_Square) {
MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type;
return RET_ERROR;
}
if (this->primitive_->value.value == nullptr) {
auto attr = new (std::nothrow) schema::SquareT();
if (attr == nullptr) {
MS_LOG(ERROR) << "new primitiveT value failed";
return RET_ERROR;
}
this->primitive_->value.value = attr;
if (this->primitive_->value.value == nullptr) {
MS_LOG(ERROR) << "primitive value is nullptr";
return RET_ERROR;
}
}
return RET_OK;
}
#else
int Square::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
MS_ASSERT(nullptr != primitive);


+ 1
- 0
mindspore/lite/src/ops/square.h View File

@@ -31,6 +31,7 @@ class Square : public ArithmeticSelf {
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(Square, ArithmeticSelf);
explicit Square(schema::PrimitiveT *primitive) : ArithmeticSelf(primitive) {}
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
#else
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif


+ 9
- 6
mindspore/lite/src/ops/tile.cc View File

@@ -52,12 +52,6 @@ int Tile::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &input
MS_LOG(ERROR) << "new primitiveT value failed";
return RET_ERROR;
}
if (prim.GetAttr("dims") == nullptr) {
MS_LOG(INFO) << "Tile's attr dims is set to default";
attr->dims = {1};
} else {
attr->dims = CastToInt(prim.GetAttr("dims"));
}
if (inputs.size() == kAnfPopulaterInputNumTwo) {
auto inputNode = inputs[kAnfPopulaterInputNumOne];
MS_ASSERT(inputNode != nullptr);
@@ -80,6 +74,15 @@ int Tile::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &input
}
}
}
if (prim.GetAttr("dims") == nullptr) {
MS_LOG(INFO) << "Tile's attr dims is set to default. The operator in mindspore has no attribute"
"named dims and all the dimensions needs to be multiplied by default.";
for (size_t i = 0; i < attr->multiples.size(); i++) {
attr->dims.push_back(i);
}
} else {
attr->dims = CastToInt(prim.GetAttr("dims"));
}
this->primitive_->value.value = attr;
}
return RET_OK;


+ 43
- 3
mindspore/lite/src/ops/topk.cc View File

@@ -28,7 +28,38 @@ bool TopK::GetSorted() const { return this->primitive_->value.AsTopK()->sorted;

void TopK::SetK(int k) { this->primitive_->value.AsTopK()->k = k; }
void TopK::SetSorted(bool sorted) { this->primitive_->value.AsTopK()->sorted = sorted; }

int TopK::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
if (this->primitive_ == nullptr) {
this->primitive_ = new (std::nothrow) schema::PrimitiveT;
if (this->primitive_ == nullptr) {
MS_LOG(ERROR) << "new primitiveT failed";
return RET_ERROR;
}
this->primitive_->value.type = schema::PrimitiveType_TopK;
}
if (this->primitive_->value.type != schema::PrimitiveType_TopK) {
MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type;
return RET_ERROR;
}
if (this->primitive_->value.value == nullptr) {
auto attr = new (std::nothrow) schema::TopKT();
if (attr == nullptr) {
MS_LOG(ERROR) << "new primitiveT value failed";
return RET_ERROR;
}
this->primitive_->value.value = attr;
// the k value of mindspore models is one of inputs instead of an attribute.
attr->k = 0;
if (prim.GetAttr("sorted") != nullptr) {
attr->sorted = GetValue<bool>(prim.GetAttr("sorted"));
}
if (this->primitive_->value.value == nullptr) {
MS_LOG(ERROR) << "primitive value is nullptr";
return RET_ERROR;
}
}
return RET_OK;
}
#else

int TopK::GetK() const { return this->primitive_->value_as_TopK()->k(); }
@@ -60,7 +91,7 @@ int TopK::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> output
}
auto input = inputs_.front();
MS_ASSERT(input != nullptr);
if (input->format() != schema::Format::Format_NHWC) {
if (input->shape().size() == kDimension_4d && input->format() != schema::Format::Format_NHWC) {
MS_LOG(ERROR) << "topk only support NHWC now!";
return RET_FORMAT_ERR;
}
@@ -76,7 +107,16 @@ int TopK::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> output
return RET_INFER_INVALID;
}
auto out_shape = input->shape();
out_shape.at(out_shape.size() - 1) = GetK();
if (inputs_.size() == kSingleNum) {
out_shape.at(out_shape.size() - 1) = GetK();
} else if (inputs_.size() == kDoubleNum) {
if (inputs_.at(1)->data_c() == nullptr) {
return RET_INFER_INVALID;
} else {
int *data = reinterpret_cast<int32_t *>(inputs_.at(1)->data_c());
out_shape.at(out_shape.size() - 1) = *data;
}
}
if (inputs_.size() == kDoubleNum && inputs_.at(1)->data_c() != nullptr) {
out_shape.at(out_shape.size() - 1) = reinterpret_cast<int *>(inputs_.at(1)->data_c())[0];
}


+ 1
- 0
mindspore/lite/src/ops/topk.h View File

@@ -34,6 +34,7 @@ class TopK : public PrimitiveC {
explicit TopK(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
void SetK(int k);
void SetSorted(bool sorted);
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
#else
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif


+ 1
- 0
mindspore/lite/src/runtime/kernel/arm/fp32/gatherNd_fp32.cc View File

@@ -152,4 +152,5 @@ kernel::LiteKernel *CpuGatherNdFp32KernelCreator(const std::vector<lite::Tensor
}

REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_GatherNd, CpuGatherNdFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_GatherNd, CpuGatherNdFp32KernelCreator)
} // namespace mindspore::kernel

Loading…
Cancel
Save