Browse Source

!6286 [MSLITE] Add converter method for operator 'Cast'.

Merge pull request !6286 from wangshaocong/lite_convert
tags/v1.0.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
c0f3e82f9e
3 changed files with 36 additions and 2 deletions
  1. +33
    -0
      mindspore/lite/src/ops/cast.cc
  2. +1
    -0
      mindspore/lite/src/ops/cast.h
  3. +2
    -2
      mindspore/lite/src/ops/primitive_c.cc

+ 33
- 0
mindspore/lite/src/ops/cast.cc View File

@@ -25,6 +25,39 @@ int Cast::GetDstT() const { return this->primitive_->value.AsCast()->dstT; }
void Cast::SetSrcT(int src_t) { this->primitive_->value.AsCast()->srcT = src_t; } void Cast::SetSrcT(int src_t) { this->primitive_->value.AsCast()->srcT = src_t; }
void Cast::SetDstT(int dst_t) { this->primitive_->value.AsCast()->dstT = dst_t; } void Cast::SetDstT(int dst_t) { this->primitive_->value.AsCast()->dstT = dst_t; }


int Cast::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
if (this->primitive_ == nullptr) {
this->primitive_ = new (std::nothrow) schema::PrimitiveT;
if (this->primitive_ == nullptr) {
MS_LOG(ERROR) << "new primitiveT failed";
return RET_ERROR;
}
this->primitive_->value.type = schema::PrimitiveType_Cast;
}
if (this->primitive_->value.type != schema::PrimitiveType_Cast) {
MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type;
return RET_ERROR;
}
if (this->primitive_->value.value == nullptr) {
auto attr = new (std::nothrow) schema::CastT();
if (attr == nullptr) {
MS_LOG(ERROR) << "new primitiveT value failed";
return RET_ERROR;
}
auto srcAnf = reinterpret_cast<mindspore::Number *>(prim.GetAttr("SrcT").get());
auto dstAnf = reinterpret_cast<mindspore::Number *>(prim.GetAttr("DstT").get());
attr->srcT = srcAnf->number_type();
attr->dstT = dstAnf->number_type();
this->primitive_->value.value = attr;
if (this->primitive_->value.value == nullptr) {
MS_LOG(ERROR) << "primitive value is nullptr";
return RET_ERROR;
}
}

return RET_OK;
}

#else #else
int Cast::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { int Cast::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
MS_ASSERT(nullptr != primitive); MS_ASSERT(nullptr != primitive);


+ 1
- 0
mindspore/lite/src/ops/cast.h View File

@@ -33,6 +33,7 @@ class Cast : public PrimitiveC {
explicit Cast(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} explicit Cast(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
void SetSrcT(int src_t); void SetSrcT(int src_t);
void SetDstT(int dst_t); void SetDstT(int dst_t);
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs);
#else #else
Cast() = default; Cast() = default;




+ 2
- 2
mindspore/lite/src/ops/primitive_c.cc View File

@@ -399,8 +399,8 @@ std::shared_ptr<PrimitiveC> PrimitiveC::Create(const Primitive &prim, const std:
return NewPrimitiveC<SoftMax>(prim, inputs, quantType); return NewPrimitiveC<SoftMax>(prim, inputs, quantType);
} else if (op_type == "StridedSlice") { } else if (op_type == "StridedSlice") {
return NewPrimitiveC<StridedSlice>(prim, inputs, quantType); return NewPrimitiveC<StridedSlice>(prim, inputs, quantType);
} else if (op_type == "AvgPool") {
return NewPrimitiveC<Pooling>(prim, inputs, quantType);
} else if (op_type == "Cast") {
return NewPrimitiveC<Cast>(prim, inputs, quantType);




#ifdef SUPPORT_TRAIN #ifdef SUPPORT_TRAIN


Loading…
Cancel
Save